hynt committed
Commit 6f024ab · 1 Parent(s): ca5f3c8

update zipvoice demo

app.py CHANGED
@@ -1,7 +1,94 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import spaces
+ import os
+ from huggingface_hub import login
  import gradio as gr
+ from cached_path import cached_path
+ import tempfile
+ from vinorm import TTSnorm
+ from infer_zipvoice import model, vocoder, tokenizer, feature_extractor, device, sampling_rate, generate_sentence
+ from utils import preprocess_ref_audio_text, save_spectrogram
+
+ # Retrieve token from secrets
+ hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
+
+ # Log in to Hugging Face
+ if hf_token:
+     login(token=hf_token)
+
+ def post_process(text):
+     text = " " + text + " "
+     text = text.replace(" . . ", " . ")
+     text = " " + text + " "
+     text = text.replace(" .. ", " . ")
+     text = " " + text + " "
+     text = text.replace(" , , ", " , ")
+     text = " " + text + " "
+     text = text.replace(" ,, ", " , ")
+     text = " " + text + " "
+     text = text.replace('"', "")
+     return " ".join(text.split())
+
+ @spaces.GPU
+ def infer_tts(ref_audio_orig: str, gen_text: str, speed: float = 1.0, request: gr.Request = None):
+
+     if not ref_audio_orig:
+         raise gr.Error("Please upload a sample audio file.")
+     if not gen_text.strip():
+         raise gr.Error("Please enter the text content to generate voice.")
+     if len(gen_text.split()) > 1000:
+         raise gr.Error("Please enter text content with fewer than 1000 words.")
+
+     try:
+         ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, "")
+         final_wave = generate_sentence(
+             ref_text.lower(),
+             ref_audio,
+             post_process(TTSnorm(gen_text)).lower(),
+             model=model,
+             vocoder=vocoder,
+             tokenizer=tokenizer,
+             feature_extractor=feature_extractor,
+             device=device,
+             speed=speed
+         )
+         # generate_sentence returns a (1, T) tensor; flatten to a numpy array for Gradio and the spectrogram plot
+         final_wave = final_wave.squeeze(0).numpy()
+         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
+             spectrogram_path = tmp_spectrogram.name
+         save_spectrogram(final_wave, spectrogram_path)
+
+         return (sampling_rate, final_wave), spectrogram_path
+     except Exception as e:
+         raise gr.Error(f"Error generating voice: {e}")
+
+ # Gradio UI
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+ # 🎤 ZipVoice: Vietnamese Text-to-Speech Synthesis
+ The model was trained with approximately 150 hours of data on an RTX 3090 GPU.
+ Enter text and upload a sample voice to generate natural speech.
+     """)
+
+     with gr.Row():
+         ref_audio = gr.Audio(label="🔊 Sample Voice", type="filepath")
+         gen_text = gr.Textbox(label="📝 Text", placeholder="Enter the text to generate voice...", lines=3)
+
+     speed = gr.Slider(0.3, 2.0, value=1.0, step=0.1, label="⚡ Speed")
+     btn_synthesize = gr.Button("🔥 Generate Voice")
+
+     with gr.Row():
+         output_audio = gr.Audio(label="🎧 Generated Audio", type="numpy")
+         output_spectrogram = gr.Image(label="📊 Spectrogram")
+
+     model_limitations = gr.Textbox(
+         value="""1. This model may not perform well with numbers, dates, special characters, etc. => A text normalization module is needed.
+ 2. The rhythm of some generated audio may be inconsistent or choppy => Choose clearly pronounced sample audio with minimal pauses for better synthesis quality.
+ 3. By default, the reference audio is transcribed with the PhoWhisper-medium model, which may not always recognize Vietnamese accurately, resulting in poor synthesis quality.
+ 4. Inference on overly long paragraphs may produce poor results.""",
+         label="❗ Model Limitations",
+         lines=4,
+         interactive=False
+     )
+
+     btn_synthesize.click(infer_tts, inputs=[ref_audio, gen_text, speed], outputs=[output_audio, output_spectrogram])
+
+ # Run Gradio (pass share=True to launch() to get a gradio.live link)
+ demo.queue().launch()
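
For reference, a minimal offline sketch of what the infer_tts handler does without Gradio. It assumes the module-level objects exposed by infer_zipvoice.py (added below); the prompt path and text are illustrative only.

import soundfile as sf
from vinorm import TTSnorm
from infer_zipvoice import (model, vocoder, tokenizer, feature_extractor,
                            device, sampling_rate, generate_sentence)
from utils import preprocess_ref_audio_text

# An empty ref_text makes utils.py transcribe the prompt with PhoWhisper
ref_audio, ref_text = preprocess_ref_audio_text("prompt.wav", "")
wav = generate_sentence(
    ref_text.lower(), ref_audio, TTSnorm("xin chào các bạn").lower(),
    model=model, vocoder=vocoder, tokenizer=tokenizer,
    feature_extractor=feature_extractor, device=device, speed=1.0,
)
sf.write("result.wav", wav.squeeze(0).numpy(), sampling_rate)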
infer_zipvoice.py ADDED
@@ -0,0 +1,302 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script generates speech with our pre-trained ZipVoice or
20
+ ZipVoice-Distill models. If no local model is specified,
21
+ required files will be automatically downloaded from HuggingFace.
22
+
23
+ Usage:
24
+
25
+ Note: If you have trouble connecting to HuggingFace,
26
+ try switching the endpoint to the mirror site:
27
+ export HF_ENDPOINT=https://hf-mirror.com
28
+
29
+ (1) Inference of a single sentence:
30
+
31
+ python3 -m zipvoice.bin.infer_zipvoice \
32
+ --model-name "zipvoice" \
33
+ --prompt-wav prompt.wav \
34
+ --prompt-text "I am a prompt." \
35
+ --text "I am a sentence." \
36
+ --res-wav-path result.wav
37
+
38
+ (2) Inference of a list of sentences:
39
+
40
+ python3 -m zipvoice.bin.infer_zipvoice \
41
+ --model-name "zipvoice" \
42
+ --test-list test.tsv \
43
+ --res-dir results
44
+
45
+ `--model-name` can be `zipvoice` or `zipvoice_distill`,
46
+ which are the models before and after distillation, respectively.
47
+
48
+ Each line of `test.tsv` is in the format of
49
+ `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
50
+ """
51
+
52
+ import argparse
53
+ import datetime as dt
54
+ import json
55
+ import os
56
+ from typing import Optional
57
+
58
+ import numpy as np
59
+ import safetensors.torch
60
+ import torch
61
+ import torchaudio
62
+ from huggingface_hub import hf_hub_download
63
+ from lhotse.utils import fix_random_seed
64
+ from vocos import Vocos
65
+
66
+ from zipvoice.models.zipvoice import ZipVoice
67
+ from zipvoice.models.zipvoice_distill import ZipVoiceDistill
68
+ from zipvoice.tokenizer.tokenizer import (
69
+ EmiliaTokenizer,
70
+ EspeakTokenizer,
71
+ LibriTTSTokenizer,
72
+ SimpleTokenizer,
73
+ )
74
+ from zipvoice.utils.checkpoint import load_checkpoint
75
+ from zipvoice.utils.common import AttributeDict
76
+ from zipvoice.utils.feature import VocosFbank
77
+
78
+ HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
79
+ PRETRAINED_MODEL = {
80
+ "zipvoice": "zipvoice/model.pt",
81
+ "zipvoice_distill": "zipvoice_distill/model.pt",
82
+ }
83
+ TOKEN_FILE = {
84
+ "zipvoice": "zipvoice/tokens.txt",
85
+ "zipvoice_distill": "zipvoice_distill/tokens.txt",
86
+ }
87
+ MODEL_CONFIG = {
88
+ "zipvoice": "zipvoice/zipvoice_base.json",
89
+ "zipvoice_distill": "zipvoice_distill/zipvoice_base.json",
90
+ }
91
+
92
+ torch.set_num_threads(1)
93
+ torch.set_num_interop_threads(1)
94
+
95
+ def get_vocoder(vocos_local_path: Optional[str] = None):
96
+ if vocos_local_path:
97
+ vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
98
+ state_dict = torch.load(
99
+ f"{vocos_local_path}/pytorch_model.bin",
100
+ weights_only=True,
101
+ map_location="cpu",
102
+ )
103
+ vocoder.load_state_dict(state_dict)
104
+ else:
105
+ vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
106
+ return vocoder
107
+
108
+
109
+ def generate_sentence(
110
+ prompt_text: str,
111
+ prompt_wav: str,
112
+ text: str,
113
+ model: torch.nn.Module,
114
+ vocoder: torch.nn.Module,
115
+ tokenizer: EmiliaTokenizer,
116
+ feature_extractor: VocosFbank,
117
+ device: torch.device,
118
+ num_step: int = 16,
119
+ guidance_scale: float = 1.0,
120
+ speed: float = 1.0,
121
+ t_shift: float = 0.5,
122
+ target_rms: float = 0.1,
123
+ feat_scale: float = 0.1,
124
+ sampling_rate: int = 24000,
125
+ ):
126
+ """
127
+ Generate waveform of a text based on a given prompt
128
+ waveform and its transcription.
129
+
130
+ Args:
132
+ prompt_text (str): Transcription of the prompt wav.
133
+ prompt_wav (str): Path to the prompt wav file.
134
+ text (str): Text to be synthesized into a waveform.
135
+ model (torch.nn.Module): The model used for generation.
136
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
137
+ tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
138
+ feature_extractor (VocosFbank): The feature extractor used to
139
+ extract acoustic features.
140
+ device (torch.device): The device on which computations are performed.
141
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
142
+ guidance_scale (float, optional): Scale for classifier-free guidance.
143
+ Defaults to 1.0.
144
+ speed (float, optional): Speed control. Defaults to 1.0.
145
+ t_shift (float, optional): Time shift. Defaults to 0.5.
146
+ target_rms (float, optional): Target RMS for waveform normalization.
147
+ Defaults to 0.1.
148
+ feat_scale (float, optional): Scale for features.
149
+ Defaults to 0.1.
150
+ sampling_rate (int, optional): Sampling rate for the waveform.
151
+ Defaults to 24000.
152
+ Returns:
153
+ wav (torch.Tensor): The generated waveform on CPU, with shape (1, num_samples).
155
+ """
156
+ # Convert text to tokens
157
+ tokens = tokenizer.texts_to_token_ids([text])
158
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
159
+
160
+ # Load and preprocess prompt wav
161
+ prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
162
+
163
+ if prompt_sampling_rate != sampling_rate:
164
+ resampler = torchaudio.transforms.Resample(
165
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
166
+ )
167
+ prompt_wav = resampler(prompt_wav)
168
+
169
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
170
+ if prompt_rms < target_rms:
171
+ prompt_wav = prompt_wav * target_rms / prompt_rms
172
+
173
+ # Extract features from prompt wav
174
+ prompt_features = feature_extractor.extract(
175
+ prompt_wav, sampling_rate=sampling_rate
176
+ ).to(device)
177
+
178
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
179
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
180
+
181
+ # Start timing
182
+ start_t = dt.datetime.now()
183
+
184
+ # Generate features
185
+ (
186
+ pred_features,
187
+ pred_features_lens,
188
+ pred_prompt_features,
189
+ pred_prompt_features_lens,
190
+ ) = model.sample(
191
+ tokens=tokens,
192
+ prompt_tokens=prompt_tokens,
193
+ prompt_features=prompt_features,
194
+ prompt_features_lens=prompt_features_lens,
195
+ speed=speed,
196
+ t_shift=t_shift,
197
+ duration="predict",
198
+ num_step=num_step,
199
+ guidance_scale=guidance_scale,
200
+ )
201
+
202
+ # Postprocess predicted features
203
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
204
+
205
+ # Start vocoder processing
206
+ start_vocoder_t = dt.datetime.now()
207
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
208
+
209
+ # Calculate processing times and real-time factors
210
+ t = (dt.datetime.now() - start_t).total_seconds()
211
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
212
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
213
+ wav_seconds = wav.shape[-1] / sampling_rate
214
+ rtf = t / wav_seconds
215
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
216
+ rtf_vocoder = t_vocoder / wav_seconds
217
+ # metrics = {
218
+ # "t": t,
219
+ # "t_no_vocoder": t_no_vocoder,
220
+ # "t_vocoder": t_vocoder,
221
+ # "wav_seconds": wav_seconds,
222
+ # "rtf": rtf,
223
+ # "rtf_no_vocoder": rtf_no_vocoder,
224
+ # "rtf_vocoder": rtf_vocoder,
225
+ # }
226
+
227
+ # Adjust wav volume if necessary
228
+ if prompt_rms < target_rms:
229
+ wav = wav * prompt_rms / target_rms
230
+ # torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
231
+ # return metrics
232
+ return wav.cpu()
233
+
234
+ model_defaults = {
235
+ "zipvoice": {
236
+ "num_step": 16,
237
+ "guidance_scale": 1.0,
238
+ },
239
+ "zipvoice_distill": {
240
+ "num_step": 8,
241
+ "guidance_scale": 3.0,
242
+ },
243
+ }
244
+
245
+ device = torch.device("cuda", 0)
246
+
247
+ print("Loading model...")
248
+ model_config = "ckpt/model.json"
249
+
250
+ with open(model_config, "r") as f:
251
+ model_config = json.load(f)
252
+
253
+ token_file = "ckpt/tokens.txt"
254
+
255
+ tokenizer = EspeakTokenizer(token_file=token_file, lang="vi")
256
+
257
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
258
+
259
+ model_ckpt = "ckpt/model.pt"
260
+
261
+ model = ZipVoice(
262
+ **model_config["model"],
263
+ **tokenizer_config,
264
+ )
265
+
266
+ load_checkpoint(filename=model_ckpt, model=model, strict=True)
267
+
268
+ model = model.to(device)
269
+ model.eval()
270
+
271
+ vocoder = get_vocoder(None)
272
+ vocoder = vocoder.to(device)
273
+ vocoder.eval()
274
+
275
+ if model_config["feature"]["type"] == "vocos":
276
+ feature_extractor = VocosFbank()
277
+ else:
278
+ raise NotImplementedError(
279
+ f"Unsupported feature type: {model_config['feature']['type']}"
280
+ )
281
+ sampling_rate = model_config["feature"]["sampling_rate"]
282
+
283
+ # generate_sentence(
284
+ # save_path=res_wav_path,
285
+ # prompt_text=prompt_text,
286
+ # prompt_wav=prompt_wav,
287
+ # text=text,
288
+ # model=model,
289
+ # vocoder=vocoder,
290
+ # tokenizer=tokenizer,
291
+ # feature_extractor=feature_extractor,
292
+ # device=device,
293
+ # num_step=16,
294
+ # guidance_scale=1.0,
295
+ # speed=speed,
296
+ # t_shift=0.5,
297
+ # target_rms=0.1,
298
+ # feat_scale=0.1,
299
+ # sampling_rate=sampling_rate,
300
+ # )
301
+
302
+ # print("Done")
requirements.txt ADDED
@@ -0,0 +1,44 @@
+ --find-links https://k2-fsa.github.io/icefall/piper_phonemize.html
+
+ torch<=2.6.0
+ torchaudio<=2.6.0
+ lhotse
+ tensorboard
+ vocos
+
+ # Normalization
+ cn2an
+ inflect
+ unidecode
+
+ # Tokenization
+ piper_phonemize
+
+ k2==1.24.4.dev20250208+cuda12.4.torch2.5.1 --find-links https://k2-fsa.github.io/k2/cuda-cn.html
+
+ transformers
+ bitsandbytes>0.37.0
+ vinorm
+ cached_path
+ huggingface_hub
+ gradio
+ accelerate>=0.33.0
+ click
+ datasets
+ ema_pytorch>=0.5.2
+ gradio>=3.45.2
+ hydra-core>=1.3.0
+ jieba
+ librosa
+ matplotlib
+ numpy<=1.26.4
+ pydub
+ pypinyin
+ safetensors
+ soundfile
+ tomli
+ torchdiffeq
+ tqdm>=4.65.0
+ transformers_stream_generator
+ wandb
+ x_transformers>=1.31.14
utils.py ADDED
@@ -0,0 +1,125 @@
+ from pydub import AudioSegment, silence
+ import tempfile
+ import hashlib
+ import matplotlib.pylab as plt
+ import librosa
+ import numpy as np
+ import torch
+ from transformers import pipeline
+
+ # Module-level state referenced below (the device default is an assumed fallback)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ asr_pipe = None
+ _ref_audio_cache = {}
+
+ def initialize_asr_pipeline(device: str = device, dtype=None):
+     if dtype is None:
+         dtype = (
+             torch.float16
+             if "cuda" in device
+             and torch.cuda.get_device_properties(device).major >= 6
+             and not torch.cuda.get_device_name().endswith("[ZLUDA]")
+             else torch.float32
+         )
+     global asr_pipe
+     asr_pipe = pipeline(
+         "automatic-speech-recognition",
+         model="vinai/PhoWhisper-medium",
+         torch_dtype=dtype,
+         device=device,
+     )
+
+ # transcribe
+ def transcribe(ref_audio, language=None):
+     global asr_pipe
+     if asr_pipe is None:
+         initialize_asr_pipeline(device=device)
+     return asr_pipe(
+         ref_audio,
+         chunk_length_s=30,
+         batch_size=128,
+         generate_kwargs={"task": "transcribe", "language": language} if language else {"task": "transcribe"},
+         return_timestamps=False,
+     )["text"].strip()
+
+ def caculate_spec(audio):
+     # Compute spectrogram (Short-Time Fourier Transform)
+     stft = librosa.stft(audio, n_fft=512, hop_length=256, win_length=512)
+     spectrogram = np.abs(stft)
+     # Convert to dB
+     spectrogram_db = librosa.amplitude_to_db(spectrogram, ref=np.max)
+     return spectrogram_db
+
+ def save_spectrogram(audio, path):
+     spectrogram = caculate_spec(audio)
+     plt.figure(figsize=(12, 4))
+     plt.imshow(spectrogram, origin="lower", aspect="auto")
+     plt.colorbar()
+     plt.savefig(path)
+     plt.close()
+
+ def preprocess_ref_audio_text(ref_audio_orig, ref_text, clip_short=True, show_info=print, device=device):
+
+     show_info("Converting audio...")
+
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
+
+         aseg = AudioSegment.from_file(ref_audio_orig)
+
+         if clip_short:
+             # 1. try to find long silence for clipping
+             non_silent_segs = silence.split_on_silence(
+                 aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000, seek_step=10
+             )
+             non_silent_wave = AudioSegment.silent(duration=0)
+             for non_silent_seg in non_silent_segs:
+                 if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
+                     show_info("Audio is over 15s, clipping short. (1)")
+                     break
+                 non_silent_wave += non_silent_seg
+
+             # 2. try to find short silence for clipping if 1. failed
+             if len(non_silent_wave) > 15000:
+                 non_silent_segs = silence.split_on_silence(
+                     aseg, min_silence_len=100, silence_thresh=-40, keep_silence=1000, seek_step=10
+                 )
+                 non_silent_wave = AudioSegment.silent(duration=0)
+                 for non_silent_seg in non_silent_segs:
+                     if len(non_silent_wave) > 6000 and len(non_silent_wave + non_silent_seg) > 15000:
+                         show_info("Audio is over 15s, clipping short. (2)")
+                         break
+                     non_silent_wave += non_silent_seg
+
+             aseg = non_silent_wave
+
+             # 3. if no proper silence found for clipping
+             if len(aseg) > 15000:
+                 aseg = aseg[:15000]
+                 show_info("Audio is over 15s, clipping short. (3)")
+
+         aseg = remove_silence_edges(aseg) + AudioSegment.silent(duration=50)
+         aseg.export(f.name, format="wav")
+         ref_audio = f.name
+
+     # Compute a hash of the reference audio file
+     with open(ref_audio, "rb") as audio_file:
+         audio_data = audio_file.read()
+         audio_hash = hashlib.md5(audio_data).hexdigest()
+
+     if not ref_text.strip():
+         global _ref_audio_cache
+         if audio_hash in _ref_audio_cache:
+             # Use cached asr transcription
+             show_info("Using cached reference text...")
+             ref_text = _ref_audio_cache[audio_hash]
+         else:
+             show_info("No reference text provided, transcribing reference audio...")
+             ref_text = transcribe(ref_audio)
+             # Cache the transcribed text (not caching custom ref_text, enabling users to do manual tweak)
+             _ref_audio_cache[audio_hash] = ref_text
+     else:
+         show_info("Using custom reference text...")
+
+     # Ensure ref_text ends with a proper sentence-ending punctuation
+     if not ref_text.endswith(". ") and not ref_text.endswith("。"):
+         if ref_text.endswith("."):
+             ref_text += " "
+         else:
+             ref_text += ". "
+
+     print("\nref_text ", ref_text)
+
+     return ref_audio, ref_text
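
preprocess_ref_audio_text() above calls remove_silence_edges(), which is neither defined nor imported in utils.py. A minimal assumed implementation, following the usual pydub pattern of trimming leading and trailing silence, would be:

from pydub import AudioSegment, silence

def remove_silence_edges(audio: AudioSegment, silence_threshold: float = -42) -> AudioSegment:
    # Drop silence at the start, then at the end (by scanning the reversed clip)
    start = silence.detect_leading_silence(audio, silence_threshold=silence_threshold)
    end = len(audio) - silence.detect_leading_silence(audio.reverse(), silence_threshold=silence_threshold)
    return audio[start:end]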
zipvoice/__init__.py ADDED
File without changes
zipvoice/bin/compute_fbank.py ADDED
@@ -0,0 +1,278 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang
3
+ # Han Zhu)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ """
19
+ Usage:
20
+ python3 -m zipvoice.bin.compute_fbank \
21
+ --source-dir data/manifests \
22
+ --dest-dir data/fbank \
23
+ --dataset libritts \
24
+ --subset dev-other \
25
+ --sampling-rate 24000 \
26
+ --num-jobs 20
27
+
28
+ The input would be data/manifests/libritts-cuts_dev-other.jsonl.gz or
29
+ (libritts_supervisions_dev-other.jsonl.gz and libritts_recordings_dev-other.jsonl.gz)
30
+
31
+ The output would be data/fbank/libritts-cuts_dev-other.jsonl.gz
32
+ """
33
+
34
+
35
+ import argparse
36
+ import logging
37
+ from concurrent.futures import ProcessPoolExecutor as Pool
38
+ from pathlib import Path
39
+
40
+ import lhotse
41
+ import torch
42
+ from lhotse import CutSet, LilcomChunkyWriter, load_manifest_lazy
43
+
44
+ from zipvoice.utils.feature import VocosFbank
45
+
46
+ # Torch's multithreaded behavior needs to be disabled or
47
+ # it wastes a lot of CPU and slows things down.
48
+ # Do this outside of main() in case it needs to take effect
49
+ # even when we are not invoking the main (e.g. when spawning subprocesses).
50
+ torch.set_num_threads(1)
51
+ torch.set_num_interop_threads(1)
52
+
53
+
54
+ def str2bool(v):
55
+ """Used in argparse.ArgumentParser.add_argument to indicate
56
+ that a type is a bool type and user can enter
57
+
58
+ - yes, true, t, y, 1, to represent True
59
+ - no, false, f, n, 0, to represent False
60
+
61
+ See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
62
+ """
63
+ if isinstance(v, bool):
64
+ return v
65
+ if v.lower() in ("yes", "true", "t", "y", "1"):
66
+ return True
67
+ elif v.lower() in ("no", "false", "f", "n", "0"):
68
+ return False
69
+ else:
70
+ raise argparse.ArgumentTypeError("Boolean value expected.")
71
+
72
+
73
+ def get_args():
74
+ parser = argparse.ArgumentParser()
75
+
76
+ parser.add_argument(
77
+ "--sampling-rate",
78
+ type=int,
79
+ default=24000,
80
+ help="The target sampling rate, the audio will be resampled to it.",
81
+ )
82
+
83
+ parser.add_argument(
84
+ "--type",
85
+ type=str,
86
+ default="vocos",
87
+ help="fbank type",
88
+ )
89
+
90
+ parser.add_argument(
91
+ "--dataset",
92
+ type=str,
93
+ help="Dataset name.",
94
+ )
95
+
96
+ parser.add_argument(
97
+ "--subset",
98
+ type=str,
99
+ help="The subset of the dataset.",
100
+ )
101
+
102
+ parser.add_argument(
103
+ "--source-dir",
104
+ type=str,
105
+ default="data/manifests",
106
+ help="The source directory of manifest files.",
107
+ )
108
+
109
+ parser.add_argument(
110
+ "--dest-dir",
111
+ type=str,
112
+ default="data/fbank",
113
+ help="The destination directory of manifest files.",
114
+ )
115
+
116
+ parser.add_argument(
117
+ "--split-cuts",
118
+ type=str2bool,
119
+ default=False,
120
+ help="Whether to use splited cuts.",
121
+ )
122
+
123
+ parser.add_argument(
124
+ "--split-begin",
125
+ type=int,
126
+ help="Start idx of splited cuts.",
127
+ )
128
+
129
+ parser.add_argument(
130
+ "--split-end",
131
+ type=int,
132
+ help="End idx of splited cuts.",
133
+ )
134
+
135
+ parser.add_argument(
136
+ "--batch-duration",
137
+ type=int,
138
+ default=1000,
139
+ help="The batch duration when computing the features.",
140
+ )
141
+
142
+ parser.add_argument(
143
+ "--num-jobs",
144
+ type=int,
145
+ default=20,
146
+ help="The number of extractor workers.",
147
+ )
148
+
149
+ return parser.parse_args()
150
+
151
+
152
+ def compute_fbank_split_single(params, idx):
153
+ lhotse.set_audio_duration_mismatch_tolerance(0.1) # for emilia
154
+ src_dir = Path(params.source_dir)
155
+ output_dir = Path(params.dest_dir)
156
+
157
+ if not src_dir.exists():
158
+ logging.error(f"{src_dir} not exists")
159
+ return
160
+
161
+ if not output_dir.exists():
162
+ output_dir.mkdir(parents=True, exist_ok=True)
163
+
164
+ num_digits = 8
165
+ if params.type == "vocos":
166
+ extractor = VocosFbank()
167
+ else:
168
+ raise NotImplementedError(f"{params.type} is not supported")
169
+
170
+ prefix = params.dataset
171
+ subset = params.subset
172
+ suffix = "jsonl.gz"
173
+
174
+ idx = f"{idx}".zfill(num_digits)
175
+ cuts_filename = f"{prefix}_cuts_{subset}.{idx}.{suffix}"
176
+
177
+ if (src_dir / cuts_filename).is_file():
178
+ logging.info(f"Loading manifests {src_dir / cuts_filename}")
179
+ cut_set = load_manifest_lazy(src_dir / cuts_filename)
180
+ else:
181
+ logging.warning(f"Raw {cuts_filename} not exists, skipping")
182
+ return
183
+
184
+ cut_set = cut_set.resample(params.sampling_rate)
185
+
186
+ if (output_dir / cuts_filename).is_file():
187
+ logging.info(f"{cuts_filename} already exists - skipping.")
188
+ return
189
+
190
+ logging.info(f"Processing {subset}.{idx} of {prefix}")
191
+
192
+ cut_set = cut_set.compute_and_store_features_batch(
193
+ extractor=extractor,
194
+ storage_path=f"{output_dir}/{prefix}_feats_{subset}_{idx}",
195
+ num_workers=4,
196
+ batch_duration=params.batch_duration,
197
+ storage_type=LilcomChunkyWriter,
198
+ overwrite=True,
199
+ )
200
+ cut_set.to_file(output_dir / cuts_filename)
201
+
202
+
203
+ def compute_fbank_split(params):
204
+ if params.split_end < params.split_begin:
205
+ logging.warning(
206
+ f"Split begin should be smaller than split end, given "
207
+ f"{params.split_begin} -> {params.split_end}."
208
+ )
209
+
210
+ with Pool(max_workers=params.num_jobs) as pool:
211
+ futures = [
212
+ pool.submit(compute_fbank_split_single, params, i)
213
+ for i in range(params.split_begin, params.split_end)
214
+ ]
215
+ for f in futures:
216
+ f.result()
217
+ f.done()
218
+
219
+
220
+ def compute_fbank(params):
221
+ src_dir = Path(params.source_dir)
222
+ output_dir = Path(params.dest_dir)
223
+ num_jobs = params.num_jobs
224
+ if not output_dir.exists():
225
+ output_dir.mkdir(parents=True, exist_ok=True)
226
+
227
+ prefix = params.dataset
228
+ subset = params.subset
229
+ suffix = "jsonl.gz"
230
+
231
+ cut_set_name = f"{prefix}_cuts_{subset}.{suffix}"
232
+
233
+ if (src_dir / cut_set_name).is_file():
234
+ logging.info(f"Loading manifests {src_dir / cut_set_name}")
235
+ cut_set = load_manifest_lazy(src_dir / cut_set_name)
236
+ else:
237
+ recordings = load_manifest_lazy(
238
+ src_dir / f"{prefix}_recordings_{subset}.{suffix}"
239
+ )
240
+ supervisions = load_manifest_lazy(
241
+ src_dir / f"{prefix}_supervisions_{subset}.{suffix}"
242
+ )
243
+ cut_set = CutSet.from_manifests(
244
+ recordings=recordings,
245
+ supervisions=supervisions,
246
+ )
247
+
248
+ cut_set = cut_set.resample(params.sampling_rate)
249
+ if params.type == "vocos":
250
+ extractor = VocosFbank()
251
+ else:
252
+ raise NotImplementedError(f"{params.type} is not supported")
253
+
254
+ cuts_filename = f"{prefix}_cuts_{subset}.{suffix}"
255
+ if (output_dir / cuts_filename).is_file():
256
+ logging.info(f"{prefix} {subset} already exists - skipping.")
257
+ return
258
+ logging.info(f"Processing {subset} of {prefix}")
259
+
260
+ cut_set = cut_set.compute_and_store_features(
261
+ extractor=extractor,
262
+ storage_path=f"{output_dir}/{prefix}_feats_{subset}",
263
+ num_jobs=num_jobs,
264
+ storage_type=LilcomChunkyWriter,
265
+ )
266
+ cut_set.to_file(output_dir / cuts_filename)
267
+
268
+
269
+ if __name__ == "__main__":
270
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
271
+
272
+ logging.basicConfig(format=formatter, level=logging.INFO)
273
+ args = get_args()
274
+ logging.info(vars(args))
275
+ if args.split_cuts:
276
+ compute_fbank_split(params=args)
277
+ else:
278
+ compute_fbank(params=args)
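
The docstring only shows the unsplit path. When --split-cuts is enabled, the script instead expects pre-split manifests named {dataset}_cuts_{subset}.{idx:08d}.jsonl.gz and processes the half-open index range [--split-begin, --split-end). An illustrative invocation (dataset and subset names are made up):

python3 -m zipvoice.bin.compute_fbank \
    --source-dir data/manifests \
    --dest-dir data/fbank \
    --dataset emilia \
    --subset train \
    --split-cuts true \
    --split-begin 0 \
    --split-end 20 \
    --num-jobs 20
# reads e.g. data/manifests/emilia_cuts_train.00000000.jsonl.gz
# writes     data/fbank/emilia_cuts_train.00000000.jsonl.gz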
zipvoice/bin/generate_averaged_model.py ADDED
@@ -0,0 +1,239 @@
1
+ #!/usr/bin/env python3
2
+ #
3
+ # Copyright 2021-2022 Xiaomi Corporation
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+ """
19
+ Usage:
20
+ This script loads checkpoints and averages them.
21
+
22
+ python3 -m zipvoice.bin.generate_averaged_model \
23
+ --epoch 11 \
24
+ --avg 4 \
25
+ --model_name zipvoice \
26
+ --model-config conf/zipvoice_base.json \
27
+ --token-file data/tokens_emilia.txt \
28
+ --exp-dir exp/zipvoice
29
+
30
+ It will generate a file `epoch-11-avg-4.pt` in the given `exp_dir`.
31
+ You can later load it by `torch.load("epoch-11-avg-4.pt")`.
32
+ """
33
+
34
+ import argparse
35
+ import json
36
+ from pathlib import Path
37
+
38
+ import torch
39
+
40
+ from zipvoice.models.zipvoice import ZipVoice
41
+ from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo
42
+ from zipvoice.models.zipvoice_distill import ZipVoiceDistill
43
+ from zipvoice.tokenizer.tokenizer import SimpleTokenizer
44
+ from zipvoice.utils.checkpoint import (
45
+ average_checkpoints_with_averaged_model,
46
+ find_checkpoints,
47
+ )
48
+ from zipvoice.utils.common import AttributeDict
49
+
50
+
51
+ def get_parser():
52
+ parser = argparse.ArgumentParser(
53
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
54
+ )
55
+
56
+ parser.add_argument(
57
+ "--epoch",
58
+ type=int,
59
+ default=11,
60
+ help="""It specifies the checkpoint to use for decoding.
61
+ Note: Epoch counts from 1.
62
+ You can specify --avg to use more checkpoints for model averaging.""",
63
+ )
64
+
65
+ parser.add_argument(
66
+ "--iter",
67
+ type=int,
68
+ default=0,
69
+ help="""If positive, --epoch is ignored and it
70
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
71
+ You can specify --avg to use more checkpoints for model averaging.
72
+ """,
73
+ )
74
+
75
+ parser.add_argument(
76
+ "--avg",
77
+ type=int,
78
+ default=4,
79
+ help="Number of checkpoints to average. Automatically select "
80
+ "consecutive checkpoints before the checkpoint specified by "
81
+ "'--epoch' or --iter",
82
+ )
83
+
84
+ parser.add_argument(
85
+ "--exp-dir",
86
+ type=str,
87
+ default="zipvoice/exp_zipvoice",
88
+ help="The experiment dir",
89
+ )
90
+
91
+ parser.add_argument(
92
+ "--model_name",
93
+ type=str,
94
+ default="zipvoice",
95
+ choices=[
96
+ "zipvoice",
97
+ "zipvoice_distill",
98
+ "zipvoice_dialog",
99
+ "zipvoice_dialog_stereo",
100
+ ],
101
+ help="The model type to be averaged. ",
102
+ )
103
+
104
+ parser.add_argument(
105
+ "--model-config",
106
+ type=str,
107
+ default="conf/zipvoice_base.json",
108
+ help="The model configuration file.",
109
+ )
110
+
111
+ parser.add_argument(
112
+ "--token-file",
113
+ type=str,
114
+ default="data/tokens_emilia.txt",
115
+ help="The file that contains information that maps tokens to ids,"
116
+ "which is a text file with '{token}\t{token_id}' per line if type is"
117
+ "char or phone, otherwise it is a bpe_model file.",
118
+ )
119
+
120
+ return parser
121
+
122
+
123
+ @torch.no_grad()
124
+ def main():
125
+ parser = get_parser()
126
+ args = parser.parse_args()
127
+ args.exp_dir = Path(args.exp_dir)
128
+ params = AttributeDict()
129
+ params.update(vars(args))
130
+
131
+ with open(params.model_config, "r") as f:
132
+ model_config = json.load(f)
133
+
134
+ tokenizer = SimpleTokenizer(token_file=params.token_file)
135
+ if params.model_name in ["zipvoice", "zipvoice_distill"]:
136
+ tokenizer_config = {
137
+ "vocab_size": tokenizer.vocab_size,
138
+ "pad_id": tokenizer.pad_id,
139
+ }
140
+ elif params.model_name in ["zipvoice_dialog", "zipvoice_dialog_stereo"]:
141
+ tokenizer_config = {
142
+ "vocab_size": tokenizer.vocab_size,
143
+ "pad_id": tokenizer.pad_id,
144
+ "spk_a_id": tokenizer.spk_a_id,
145
+ "spk_b_id": tokenizer.spk_a_id,
146
+ }
147
+
148
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
149
+
150
+ print("Script started")
151
+
152
+ params.device = torch.device("cpu")
153
+ print(f"Device: {params.device}")
154
+
155
+ print("About to create model")
156
+ if params.model_name == "zipvoice":
157
+ model = ZipVoice(
158
+ **model_config["model"],
159
+ **tokenizer_config,
160
+ )
161
+ elif params.model_name == "zipvoice_distill":
162
+ model = ZipVoiceDistill(
163
+ **model_config["model"],
164
+ **tokenizer_config,
165
+ )
166
+ elif params.model_name == "zipvoice_dialog":
167
+ model = ZipVoiceDialog(
168
+ **model_config["model"],
169
+ **tokenizer_config,
170
+ )
171
+ elif params.model_name == "zipvoice_dialog_stereo":
172
+ model = ZipVoiceDialogStereo(
173
+ **model_config["model"],
174
+ **tokenizer_config,
175
+ )
176
+ else:
177
+ raise ValueError(f"Unknown model name: {params.model_name}")
178
+
179
+ if params.iter > 0:
180
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
181
+ : params.avg + 1
182
+ ]
183
+ if len(filenames) == 0:
184
+ raise ValueError(
185
+ f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}"
186
+ )
187
+ elif len(filenames) < params.avg + 1:
188
+ raise ValueError(
189
+ f"Not enough checkpoints ({len(filenames)}) found for"
190
+ f" --iter {params.iter}, --avg {params.avg}"
191
+ )
192
+ filename_start = filenames[-1]
193
+ filename_end = filenames[0]
194
+ print(
195
+ "Calculating the averaged model over iteration checkpoints"
196
+ f" from {filename_start} (excluded) to {filename_end}"
197
+ )
198
+ model.to(params.device)
199
+ model.load_state_dict(
200
+ average_checkpoints_with_averaged_model(
201
+ filename_start=filename_start,
202
+ filename_end=filename_end,
203
+ device=params.device,
204
+ ),
205
+ strict=True,
206
+ )
207
+ else:
208
+ assert params.avg > 0, params.avg
209
+ start = params.epoch - params.avg
210
+ assert start >= 1, start
211
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
212
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
213
+ print(
214
+ f"Calculating the averaged model over epoch range from "
215
+ f"{start} (excluded) to {params.epoch}"
216
+ )
217
+ model.to(params.device)
218
+ model.load_state_dict(
219
+ average_checkpoints_with_averaged_model(
220
+ filename_start=filename_start,
221
+ filename_end=filename_end,
222
+ device=params.device,
223
+ ),
224
+ strict=True,
225
+ )
226
+ if params.iter > 0:
227
+ filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
228
+ else:
229
+ filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
230
+ torch.save({"model": model.state_dict()}, filename)
231
+
232
+ num_param = sum([p.numel() for p in model.parameters()])
233
+ print(f"Number of model parameters: {num_param}")
234
+
235
+ print("Done!")
236
+
237
+
238
+ if __name__ == "__main__":
239
+ main()
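
The averaged file is written above as {"model": state_dict}, so it can be read back with the key "model" before loading it into a freshly constructed ZipVoice model. A small sketch, using the output path from the docstring example:

import torch

ckpt = torch.load("exp/zipvoice/epoch-11-avg-4.pt", map_location="cpu")
state_dict = ckpt["model"]          # the only key written by torch.save() above
print(len(state_dict), "tensors")   # pass state_dict to ZipVoice(...).load_state_dict(...)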
zipvoice/bin/infer_zipvoice.py ADDED
@@ -0,0 +1,617 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script generates speech with our pre-trained ZipVoice or
20
+ ZipVoice-Distill models. If no local model is specified,
21
+ required files will be automatically downloaded from HuggingFace.
22
+
23
+ Usage:
24
+
25
+ Note: If you have trouble connecting to HuggingFace,
26
+ try switching the endpoint to the mirror site:
27
+ export HF_ENDPOINT=https://hf-mirror.com
28
+
29
+ (1) Inference of a single sentence:
30
+
31
+ python3 -m zipvoice.bin.infer_zipvoice \
32
+ --model-name "zipvoice" \
33
+ --prompt-wav prompt.wav \
34
+ --prompt-text "I am a prompt." \
35
+ --text "I am a sentence." \
36
+ --res-wav-path result.wav
37
+
38
+ (2) Inference of a list of sentences:
39
+
40
+ python3 -m zipvoice.bin.infer_zipvoice \
41
+ --model-name "zipvoice" \
42
+ --test-list test.tsv \
43
+ --res-dir results
44
+
45
+ `--model-name` can be `zipvoice` or `zipvoice_distill`,
46
+ which are the models before and after distillation, respectively.
47
+
48
+ Each line of `test.tsv` is in the format of
49
+ `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
50
+ """
51
+
52
+ import argparse
53
+ import datetime as dt
54
+ import json
55
+ import os
56
+ from typing import Optional
57
+
58
+ import numpy as np
59
+ import safetensors.torch
60
+ import torch
61
+ import torchaudio
62
+ from huggingface_hub import hf_hub_download
63
+ from lhotse.utils import fix_random_seed
64
+ from vocos import Vocos
65
+
66
+ from zipvoice.models.zipvoice import ZipVoice
67
+ from zipvoice.models.zipvoice_distill import ZipVoiceDistill
68
+ from zipvoice.tokenizer.tokenizer import (
69
+ EmiliaTokenizer,
70
+ EspeakTokenizer,
71
+ LibriTTSTokenizer,
72
+ SimpleTokenizer,
73
+ )
74
+ from zipvoice.utils.checkpoint import load_checkpoint
75
+ from zipvoice.utils.common import AttributeDict
76
+ from zipvoice.utils.feature import VocosFbank
77
+
78
+ HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
79
+ PRETRAINED_MODEL = {
80
+ "zipvoice": "zipvoice/model.pt",
81
+ "zipvoice_distill": "zipvoice_distill/model.pt",
82
+ }
83
+ TOKEN_FILE = {
84
+ "zipvoice": "zipvoice/tokens.txt",
85
+ "zipvoice_distill": "zipvoice_distill/tokens.txt",
86
+ }
87
+ MODEL_CONFIG = {
88
+ "zipvoice": "zipvoice/zipvoice_base.json",
89
+ "zipvoice_distill": "zipvoice_distill/zipvoice_base.json",
90
+ }
91
+
92
+
93
+ def get_parser():
94
+ parser = argparse.ArgumentParser(
95
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
96
+ )
97
+
98
+ parser.add_argument(
99
+ "--model-name",
100
+ type=str,
101
+ default="zipvoice",
102
+ choices=["zipvoice", "zipvoice_distill"],
103
+ help="The model used for inference",
104
+ )
105
+
106
+ parser.add_argument(
107
+ "--checkpoint",
108
+ type=str,
109
+ default=None,
110
+ help="The model checkpoint. "
111
+ "Will download pre-trained checkpoint from huggingface if not specified.",
112
+ )
113
+
114
+ parser.add_argument(
115
+ "--model-config",
116
+ type=str,
117
+ default=None,
118
+ help="The model configuration file. "
119
+ "Will download zipvoice_base.json from huggingface if not specified.",
120
+ )
121
+
122
+ parser.add_argument(
123
+ "--vocoder-path",
124
+ type=str,
125
+ default=None,
126
+ help="The vocoder checkpoint. "
127
+ "Will download pre-trained vocoder from huggingface if not specified.",
128
+ )
129
+
130
+ parser.add_argument(
131
+ "--token-file",
132
+ type=str,
133
+ default=None,
134
+ help="The file that contains information that maps tokens to ids,"
135
+ "which is a text file with '{token}\t{token_id}' per line. "
136
+ "Will download tokens_emilia.txt from huggingface if not specified.",
137
+ )
138
+
139
+ parser.add_argument(
140
+ "--tokenizer",
141
+ type=str,
142
+ default="emilia",
143
+ choices=["emilia", "libritts", "espeak", "simple"],
144
+ help="Tokenizer type.",
145
+ )
146
+
147
+ parser.add_argument(
148
+ "--lang",
149
+ type=str,
150
+ default="en-us",
151
+ help="Language identifier, used when tokenizer type is espeak. see"
152
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
153
+ )
154
+
155
+ parser.add_argument(
156
+ "--test-list",
157
+ type=str,
158
+ default=None,
159
+ help="The list of prompt speech, prompt_transcription, "
160
+ "and text to synthesizein the format of "
161
+ "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
162
+ )
163
+
164
+ parser.add_argument(
165
+ "--prompt-wav",
166
+ type=str,
167
+ default=None,
168
+ help="The prompt wav to mimic",
169
+ )
170
+
171
+ parser.add_argument(
172
+ "--prompt-text",
173
+ type=str,
174
+ default=None,
175
+ help="The transcription of the prompt wav",
176
+ )
177
+
178
+ parser.add_argument(
179
+ "--text",
180
+ type=str,
181
+ default=None,
182
+ help="The text to synthesize",
183
+ )
184
+
185
+ parser.add_argument(
186
+ "--res-dir",
187
+ type=str,
188
+ default="results",
189
+ help="""
190
+ Path name of the generated wavs dir,
191
+ used when test-list is not None
192
+ """,
193
+ )
194
+
195
+ parser.add_argument(
196
+ "--res-wav-path",
197
+ type=str,
198
+ default="result.wav",
199
+ help="""
200
+ Path name of the generated wav path,
201
+ used when test-list is None
202
+ """,
203
+ )
204
+
205
+ parser.add_argument(
206
+ "--guidance-scale",
207
+ type=float,
208
+ default=None,
209
+ help="The scale of classifier-free guidance during inference.",
210
+ )
211
+
212
+ parser.add_argument(
213
+ "--num-step",
214
+ type=int,
215
+ default=None,
216
+ help="The number of sampling steps.",
217
+ )
218
+
219
+ parser.add_argument(
220
+ "--feat-scale",
221
+ type=float,
222
+ default=0.1,
223
+ help="The scale factor of fbank feature",
224
+ )
225
+
226
+ parser.add_argument(
227
+ "--speed",
228
+ type=float,
229
+ default=1.0,
230
+ help="Control speech speed, 1.0 means normal, >1.0 means speed up",
231
+ )
232
+
233
+ parser.add_argument(
234
+ "--t-shift",
235
+ type=float,
236
+ default=0.5,
237
+ help="Shift t to smaller ones if t_shift < 1.0",
238
+ )
239
+
240
+ parser.add_argument(
241
+ "--target-rms",
242
+ type=float,
243
+ default=0.1,
244
+ help="Target speech normalization rms value, set to 0 to disable normalization",
245
+ )
246
+
247
+ parser.add_argument(
248
+ "--seed",
249
+ type=int,
250
+ default=666,
251
+ help="Random seed",
252
+ )
253
+
254
+ return parser
255
+
256
+
257
+ def get_vocoder(vocos_local_path: Optional[str] = None):
258
+ if vocos_local_path:
259
+ vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
260
+ state_dict = torch.load(
261
+ f"{vocos_local_path}/pytorch_model.bin",
262
+ weights_only=True,
263
+ map_location="cpu",
264
+ )
265
+ vocoder.load_state_dict(state_dict)
266
+ else:
267
+ vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
268
+ return vocoder
269
+
270
+
271
+ def generate_sentence(
272
+ save_path: str,
273
+ prompt_text: str,
274
+ prompt_wav: str,
275
+ text: str,
276
+ model: torch.nn.Module,
277
+ vocoder: torch.nn.Module,
278
+ tokenizer: EmiliaTokenizer,
279
+ feature_extractor: VocosFbank,
280
+ device: torch.device,
281
+ num_step: int = 16,
282
+ guidance_scale: float = 1.0,
283
+ speed: float = 1.0,
284
+ t_shift: float = 0.5,
285
+ target_rms: float = 0.1,
286
+ feat_scale: float = 0.1,
287
+ sampling_rate: int = 24000,
288
+ ):
289
+ """
290
+ Generate waveform of a text based on a given prompt
291
+ waveform and its transcription.
292
+
293
+ Args:
294
+ save_path (str): Path to save the generated wav.
295
+ prompt_text (str): Transcription of the prompt wav.
296
+ prompt_wav (str): Path to the prompt wav file.
297
+ text (str): Text to be synthesized into a waveform.
298
+ model (torch.nn.Module): The model used for generation.
299
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
300
+ tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
301
+ feature_extractor (VocosFbank): The feature extractor used to
302
+ extract acoustic features.
303
+ device (torch.device): The device on which computations are performed.
304
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
305
+ guidance_scale (float, optional): Scale for classifier-free guidance.
306
+ Defaults to 1.0.
307
+ speed (float, optional): Speed control. Defaults to 1.0.
308
+ t_shift (float, optional): Time shift. Defaults to 0.5.
309
+ target_rms (float, optional): Target RMS for waveform normalization.
310
+ Defaults to 0.1.
311
+ feat_scale (float, optional): Scale for features.
312
+ Defaults to 0.1.
313
+ sampling_rate (int, optional): Sampling rate for the waveform.
314
+ Defaults to 24000.
315
+ Returns:
316
+ metrics (dict): Dictionary containing time and real-time
317
+ factor metrics for processing.
318
+ """
319
+ # Convert text to tokens
320
+ tokens = tokenizer.texts_to_token_ids([text])
321
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
322
+
323
+ # Load and preprocess prompt wav
324
+ prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
325
+
326
+ if prompt_sampling_rate != sampling_rate:
327
+ resampler = torchaudio.transforms.Resample(
328
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
329
+ )
330
+ prompt_wav = resampler(prompt_wav)
331
+
332
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
333
+ if prompt_rms < target_rms:
334
+ prompt_wav = prompt_wav * target_rms / prompt_rms
335
+
336
+ # Extract features from prompt wav
337
+ prompt_features = feature_extractor.extract(
338
+ prompt_wav, sampling_rate=sampling_rate
339
+ ).to(device)
340
+
341
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
342
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
343
+
344
+ # Start timing
345
+ start_t = dt.datetime.now()
346
+
347
+ # Generate features
348
+ (
349
+ pred_features,
350
+ pred_features_lens,
351
+ pred_prompt_features,
352
+ pred_prompt_features_lens,
353
+ ) = model.sample(
354
+ tokens=tokens,
355
+ prompt_tokens=prompt_tokens,
356
+ prompt_features=prompt_features,
357
+ prompt_features_lens=prompt_features_lens,
358
+ speed=speed,
359
+ t_shift=t_shift,
360
+ duration="predict",
361
+ num_step=num_step,
362
+ guidance_scale=guidance_scale,
363
+ )
364
+
365
+ # Postprocess predicted features
366
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
367
+
368
+ # Start vocoder processing
369
+ start_vocoder_t = dt.datetime.now()
370
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
371
+
372
+ # Calculate processing times and real-time factors
373
+ t = (dt.datetime.now() - start_t).total_seconds()
374
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
375
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
376
+ wav_seconds = wav.shape[-1] / sampling_rate
377
+ rtf = t / wav_seconds
378
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
379
+ rtf_vocoder = t_vocoder / wav_seconds
380
+ metrics = {
381
+ "t": t,
382
+ "t_no_vocoder": t_no_vocoder,
383
+ "t_vocoder": t_vocoder,
384
+ "wav_seconds": wav_seconds,
385
+ "rtf": rtf,
386
+ "rtf_no_vocoder": rtf_no_vocoder,
387
+ "rtf_vocoder": rtf_vocoder,
388
+ }
389
+
390
+ # Adjust wav volume if necessary
391
+ if prompt_rms < target_rms:
392
+ wav = wav * prompt_rms / target_rms
393
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
394
+
395
+ return metrics
396
+
397
+
398
+ def generate_list(
399
+ res_dir: str,
400
+ test_list: str,
401
+ model: torch.nn.Module,
402
+ vocoder: torch.nn.Module,
403
+ tokenizer: EmiliaTokenizer,
404
+ feature_extractor: VocosFbank,
405
+ device: torch.device,
406
+ num_step: int = 16,
407
+ guidance_scale: float = 1.0,
408
+ speed: float = 1.0,
409
+ t_shift: float = 0.5,
410
+ target_rms: float = 0.1,
411
+ feat_scale: float = 0.1,
412
+ sampling_rate: int = 24000,
413
+ ):
414
+ total_t = []
415
+ total_t_no_vocoder = []
416
+ total_t_vocoder = []
417
+ total_wav_seconds = []
418
+
419
+ with open(test_list, "r") as fr:
420
+ lines = fr.readlines()
421
+
422
+ for i, line in enumerate(lines):
423
+ wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
424
+ save_path = f"{res_dir}/{wav_name}.wav"
425
+ metrics = generate_sentence(
426
+ save_path=save_path,
427
+ prompt_text=prompt_text,
428
+ prompt_wav=prompt_wav,
429
+ text=text,
430
+ model=model,
431
+ vocoder=vocoder,
432
+ tokenizer=tokenizer,
433
+ feature_extractor=feature_extractor,
434
+ device=device,
435
+ num_step=num_step,
436
+ guidance_scale=guidance_scale,
437
+ speed=speed,
438
+ t_shift=t_shift,
439
+ target_rms=target_rms,
440
+ feat_scale=feat_scale,
441
+ sampling_rate=sampling_rate,
442
+ )
443
+ print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
444
+ total_t.append(metrics["t"])
445
+ total_t_no_vocoder.append(metrics["t_no_vocoder"])
446
+ total_t_vocoder.append(metrics["t_vocoder"])
447
+ total_wav_seconds.append(metrics["wav_seconds"])
448
+
449
+ print(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
450
+ print(
451
+ f"Average RTF w/o vocoder: "
452
+ f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
453
+ )
454
+ print(
455
+ f"Average RTF vocoder: "
456
+ f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
457
+ )
458
+
459
+
460
+ @torch.inference_mode()
461
+ def main():
462
+ parser = get_parser()
463
+ args = parser.parse_args()
464
+
465
+ params = AttributeDict()
466
+ params.update(vars(args))
467
+ fix_random_seed(params.seed)
468
+
469
+ model_defaults = {
470
+ "zipvoice": {
471
+ "num_step": 16,
472
+ "guidance_scale": 1.0,
473
+ },
474
+ "zipvoice_distill": {
475
+ "num_step": 8,
476
+ "guidance_scale": 3.0,
477
+ },
478
+ }
479
+
480
+ model_specific_defaults = model_defaults.get(params.model_name, {})
481
+
482
+ for param, value in model_specific_defaults.items():
483
+ if getattr(params, param) is None:
484
+ setattr(params, param, value)
485
+ print(f"Setting {param} to default value: {value}")
486
+
487
+ assert (params.test_list is not None) ^ (
488
+ (params.prompt_wav and params.prompt_text and params.text) is not None
489
+ ), (
490
+ "For inference, please provide prompts and text with either '--test-list'"
491
+ " or '--prompt-wav, --prompt-text and --text'."
492
+ )
493
+
494
+ if torch.cuda.is_available():
495
+ params.device = torch.device("cuda", 0)
496
+ elif torch.backends.mps.is_available():
497
+ params.device = torch.device("mps")
498
+ else:
499
+ params.device = torch.device("cpu")
500
+
501
+ print("Loading model...")
502
+ if params.model_config is None:
503
+ model_config = hf_hub_download(
504
+ HUGGINGFACE_REPO, filename=MODEL_CONFIG[params.model_name]
505
+ )
506
+ else:
507
+ model_config = params.model_config
508
+
509
+ with open(model_config, "r") as f:
510
+ model_config = json.load(f)
511
+
512
+ if params.token_file is None:
513
+ token_file = hf_hub_download(
514
+ HUGGINGFACE_REPO, filename=TOKEN_FILE[params.model_name]
515
+ )
516
+ else:
517
+ token_file = params.token_file
518
+
519
+ if params.tokenizer == "emilia":
520
+ tokenizer = EmiliaTokenizer(token_file=token_file)
521
+ elif params.tokenizer == "libritts":
522
+ tokenizer = LibriTTSTokenizer(token_file=token_file)
523
+ elif params.tokenizer == "espeak":
524
+ tokenizer = EspeakTokenizer(token_file=token_file, lang=params.lang)
525
+ else:
526
+ assert params.tokenizer == "simple"
527
+ tokenizer = SimpleTokenizer(token_file=token_file)
528
+
529
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
530
+
531
+ if params.checkpoint is None:
532
+ model_ckpt = hf_hub_download(
533
+ HUGGINGFACE_REPO,
534
+ filename=PRETRAINED_MODEL[params.model_name],
535
+ )
536
+ else:
537
+ model_ckpt = params.checkpoint
538
+
539
+ if params.model_name == "zipvoice":
540
+ model = ZipVoice(
541
+ **model_config["model"],
542
+ **tokenizer_config,
543
+ )
544
+ else:
545
+ assert params.model_name == "zipvoice_distill"
546
+ model = ZipVoiceDistill(
547
+ **model_config["model"],
548
+ **tokenizer_config,
549
+ )
550
+
551
+ if model_ckpt.endswith(".safetensors"):
552
+ safetensors.torch.load_model(model, model_ckpt)
553
+ elif model_ckpt.endswith(".pt"):
554
+ load_checkpoint(filename=model_ckpt, model=model, strict=True)
555
+ else:
556
+ raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
557
+
558
+ model = model.to(params.device)
559
+ model.eval()
560
+
561
+ vocoder = get_vocoder(params.vocoder_path)
562
+ vocoder = vocoder.to(params.device)
563
+ vocoder.eval()
564
+
565
+ if model_config["feature"]["type"] == "vocos":
566
+ feature_extractor = VocosFbank()
567
+ else:
568
+ raise NotImplementedError(
569
+ f"Unsupported feature type: {model_config['feature']['type']}"
570
+ )
571
+ params.sampling_rate = model_config["feature"]["sampling_rate"]
572
+
573
+ print("Start generating...")
574
+ if params.test_list:
575
+ os.makedirs(params.res_dir, exist_ok=True)
576
+ generate_list(
577
+ res_dir=params.res_dir,
578
+ test_list=params.test_list,
579
+ model=model,
580
+ vocoder=vocoder,
581
+ tokenizer=tokenizer,
582
+ feature_extractor=feature_extractor,
583
+ device=params.device,
584
+ num_step=params.num_step,
585
+ guidance_scale=params.guidance_scale,
586
+ speed=params.speed,
587
+ t_shift=params.t_shift,
588
+ target_rms=params.target_rms,
589
+ feat_scale=params.feat_scale,
590
+ sampling_rate=params.sampling_rate,
591
+ )
592
+ else:
593
+ generate_sentence(
594
+ save_path=params.res_wav_path,
595
+ prompt_text=params.prompt_text,
596
+ prompt_wav=params.prompt_wav,
597
+ text=params.text,
598
+ model=model,
599
+ vocoder=vocoder,
600
+ tokenizer=tokenizer,
601
+ feature_extractor=feature_extractor,
602
+ device=params.device,
603
+ num_step=params.num_step,
604
+ guidance_scale=params.guidance_scale,
605
+ speed=params.speed,
606
+ t_shift=params.t_shift,
607
+ target_rms=params.target_rms,
608
+ feat_scale=params.feat_scale,
609
+ sampling_rate=params.sampling_rate,
610
+ )
611
+ print("Done")
612
+
613
+
614
+ if __name__ == "__main__":
615
+ torch.set_num_threads(1)
616
+ torch.set_num_interop_threads(1)
617
+ main()
zipvoice/bin/infer_zipvoice_dialog.py ADDED
@@ -0,0 +1,756 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script generates speech with our pre-trained ZipVoice-Dialog or
20
+ ZipVoice-Dialog-Stereo models. If no local model is specified,
21
+ the required files will be automatically downloaded from HuggingFace.
22
+
23
+ Usage:
24
+
25
+ Note: If you have trouble connecting to HuggingFace,
26
+ try switching endpoint to mirror site:
27
+ export HF_ENDPOINT=https://hf-mirror.com
28
+
29
+ python3 -m zipvoice.bin.infer_zipvoice_dialog \
30
+ --model-name "zipvoice_dialog" \
31
+ --test-list test.tsv \
32
+ --res-dir results
33
+
34
+ `--model-name` can be `zipvoice_dialog` or `zipvoice_dialog_stereo`,
35
+ which generate mono and stereo dialogues, respectively.
36
+
37
+ Each line of `test.tsv` is in the format of a merged conversation:
38
+ '{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'
39
+ or a split conversation:
40
+ '{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}
41
+ \t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'
42
+ """
43
+
44
+ import argparse
45
+ import datetime as dt
46
+ import json
47
+ import os
48
+ from typing import List, Optional, Union
49
+
50
+ import numpy as np
51
+ import safetensors.torch
52
+ import torch
53
+ import torchaudio
54
+ from huggingface_hub import hf_hub_download
55
+ from lhotse.utils import fix_random_seed
56
+ from vocos import Vocos
57
+
58
+ from zipvoice.models.zipvoice_dialog import ZipVoiceDialog, ZipVoiceDialogStereo
59
+ from zipvoice.tokenizer.tokenizer import DialogTokenizer
60
+ from zipvoice.utils.checkpoint import load_checkpoint
61
+ from zipvoice.utils.common import AttributeDict
62
+ from zipvoice.utils.feature import VocosFbank
63
+
64
+ HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
65
+ PRETRAINED_MODEL = {
66
+ "zipvoice_dialog": "zipvoice_dialog/model.pt",
67
+ "zipvoice_dialog_stereo": "zipvoice_dialog_stereo/model.pt",
68
+ }
69
+ TOKEN_FILE = {
70
+ "zipvoice_dialog": "zipvoice_dialog/tokens.txt",
71
+ "zipvoice_dialog_stereo": "zipvoice_dialog_stereo/tokens.txt",
72
+ }
73
+ MODEL_CONFIG = {
74
+ "zipvoice_dialog": "zipvoice_dialog/model.json",
75
+ "zipvoice_dialog_stereo": "zipvoice_dialog_stereo/model.json",
76
+ }
77
+
78
+
79
+ def get_parser():
80
+ parser = argparse.ArgumentParser(
81
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
82
+ )
83
+
84
+ parser.add_argument(
85
+ "--model-name",
86
+ type=str,
87
+ default="zipvoice_dialog",
88
+ choices=["zipvoice_dialog", "zipvoice_dialog_stereo"],
89
+ help="The model used for inference",
90
+ )
91
+
92
+ parser.add_argument(
93
+ "--checkpoint",
94
+ type=str,
95
+ default=None,
96
+ help="The model checkpoint. "
97
+ "Will download pre-trained checkpoint from huggingface if not specified.",
98
+ )
99
+
100
+ parser.add_argument(
101
+ "--model-config",
102
+ type=str,
103
+ default=None,
104
+ help="The model configuration file. "
105
+ "Will download model.json from huggingface if not specified.",
106
+ )
107
+
108
+ parser.add_argument(
109
+ "--vocoder-path",
110
+ type=str,
111
+ default=None,
112
+ help="The vocoder checkpoint. "
113
+ "Will download pre-trained vocoder from huggingface if not specified.",
114
+ )
115
+
116
+ parser.add_argument(
117
+ "--token-file",
118
+ type=str,
119
+ default=None,
120
+ help="The file that contains information that maps tokens to ids,"
121
+ "which is a text file with '{token}\t{token_id}' per line. "
122
+ "Will download tokens_emilia.txt from huggingface if not specified.",
123
+ )
124
+
125
+ parser.add_argument(
126
+ "--test-list",
127
+ type=str,
128
+ default=None,
129
+ help="The list of prompt speech, prompt_transcription, "
130
+ "and text to synthesizein the format of merged conversation: "
131
+ "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}' "
132
+ "or splited conversation: "
133
+ "'{wav_name}\t{spk1_prompt_transcription}\t{spk2_prompt_transcription}"
134
+ "\t{spk1_prompt_wav}\t{spk2_prompt_wav}\t{text}'.",
135
+ )
136
+
137
+ parser.add_argument(
138
+ "--res-dir",
139
+ type=str,
140
+ default="results",
141
+ help="""
142
+ Path name of the generated wavs dir,
143
+ used when test-list is not None
144
+ """,
145
+ )
146
+
147
+ parser.add_argument(
148
+ "--guidance-scale",
149
+ type=float,
150
+ default=1.5,
151
+ help="The scale of classifier-free guidance during inference.",
152
+ )
153
+
154
+ parser.add_argument(
155
+ "--num-step",
156
+ type=int,
157
+ default=16,
158
+ help="The number of sampling steps.",
159
+ )
160
+
161
+ parser.add_argument(
162
+ "--feat-scale",
163
+ type=float,
164
+ default=0.1,
165
+ help="The scale factor of fbank feature",
166
+ )
167
+
168
+ parser.add_argument(
169
+ "--speed",
170
+ type=float,
171
+ default=1.0,
172
+ help="Control speech speed, 1.0 means normal, >1.0 means speed up",
173
+ )
174
+
175
+ parser.add_argument(
176
+ "--t-shift",
177
+ type=float,
178
+ default=0.5,
179
+ help="Shift t to smaller ones if t_shift < 1.0",
180
+ )
181
+
182
+ parser.add_argument(
183
+ "--target-rms",
184
+ type=float,
185
+ default=0.1,
186
+ help="Target speech normalization rms value, set to 0 to disable normalization",
187
+ )
188
+
189
+ parser.add_argument(
190
+ "--seed",
191
+ type=int,
192
+ default=666,
193
+ help="Random seed",
194
+ )
195
+
196
+ parser.add_argument(
197
+ "--silence-wav",
198
+ type=str,
199
+ default="assets/silence.wav",
200
+ help="Path of the silence wav file, used in two-channel generation "
201
+ "with single-channel prompts",
202
+ )
203
+
204
+ return parser
205
+
206
+
207
+ def get_vocoder(vocos_local_path: Optional[str] = None):
208
+ if vocos_local_path:
209
+ vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
210
+ state_dict = torch.load(
211
+ f"{vocos_local_path}/pytorch_model.bin",
212
+ weights_only=True,
213
+ map_location="cpu",
214
+ )
215
+ vocoder.load_state_dict(state_dict)
216
+ else:
217
+ vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
218
+ return vocoder
219
+
220
+
221
+ def generate_sentence(
222
+ save_path: str,
223
+ prompt_text: str,
224
+ prompt_wav: Union[str, List[str]],
225
+ text: str,
226
+ model: torch.nn.Module,
227
+ vocoder: torch.nn.Module,
228
+ tokenizer: DialogTokenizer,
229
+ feature_extractor: VocosFbank,
230
+ device: torch.device,
231
+ num_step: int = 16,
232
+ guidance_scale: float = 1.0,
233
+ speed: float = 1.0,
234
+ t_shift: float = 0.5,
235
+ target_rms: float = 0.1,
236
+ feat_scale: float = 0.1,
237
+ sampling_rate: int = 24000,
238
+ ):
239
+ """
240
+ Generate waveform of a text based on a given prompt
241
+ waveform and its transcription.
242
+
243
+ Args:
244
+ save_path (str): Path to save the generated wav.
245
+ prompt_text (str): Transcription of the prompt wav.
246
+ prompt_wav (Union[str, List[str]]): Path to the prompt wav file, can be
247
+ one or two wav files, which correspond to a merged conversational
248
+ speech or two separate speakers' speech.
249
+ text (str): Text to be synthesized into a waveform.
250
+ model (torch.nn.Module): The model used for generation.
251
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
252
+ tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
253
+ feature_extractor (VocosFbank): The feature extractor used to
254
+ extract acoustic features.
255
+ device (torch.device): The device on which computations are performed.
256
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
257
+ guidance_scale (float, optional): Scale for classifier-free guidance.
258
+ Defaults to 1.0.
259
+ speed (float, optional): Speed control. Defaults to 1.0.
260
+ t_shift (float, optional): Time shift. Defaults to 0.5.
261
+ target_rms (float, optional): Target RMS for waveform normalization.
262
+ Defaults to 0.1.
263
+ feat_scale (float, optional): Scale for features.
264
+ Defaults to 0.1.
265
+ sampling_rate (int, optional): Sampling rate for the waveform.
266
+ Defaults to 24000.
267
+ Returns:
268
+ metrics (dict): Dictionary containing time and real-time
269
+ factor metrics for processing.
270
+ """
271
+ # Convert text to tokens
272
+ tokens = tokenizer.texts_to_token_ids([text])
273
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
274
+
275
+ # Load and preprocess prompt wav
276
+ if isinstance(prompt_wav, str):
277
+ prompt_wav = [
278
+ prompt_wav,
279
+ ]
280
+ else:
281
+ assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
282
+
283
+ loaded_prompt_wavs = prompt_wav
284
+ for i in range(len(prompt_wav)):
285
+ loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i])
286
+ if prompt_sampling_rate != sampling_rate:
287
+ resampler = torchaudio.transforms.Resample(
288
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
289
+ )
290
+ loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i])
291
+
292
+ if len(loaded_prompt_wavs) == 1:
293
+ prompt_wav = loaded_prompt_wavs[0]
294
+ else:
295
+ prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
296
+
297
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
298
+ if prompt_rms < target_rms:
299
+ prompt_wav = prompt_wav * target_rms / prompt_rms
300
+
301
+ # Extract features from prompt wav
302
+ prompt_features = feature_extractor.extract(
303
+ prompt_wav, sampling_rate=sampling_rate
304
+ ).to(device)
305
+
306
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
307
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
308
+
309
+ # Start timing
310
+ start_t = dt.datetime.now()
311
+
312
+ # Generate features
313
+ (
314
+ pred_features,
315
+ pred_features_lens,
316
+ pred_prompt_features,
317
+ pred_prompt_features_lens,
318
+ ) = model.sample(
319
+ tokens=tokens,
320
+ prompt_tokens=prompt_tokens,
321
+ prompt_features=prompt_features,
322
+ prompt_features_lens=prompt_features_lens,
323
+ speed=speed,
324
+ t_shift=t_shift,
325
+ duration="predict",
326
+ num_step=num_step,
327
+ guidance_scale=guidance_scale,
328
+ )
329
+
330
+ # Postprocess predicted features
331
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
332
+
333
+ # Start vocoder processing
334
+ start_vocoder_t = dt.datetime.now()
335
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
336
+
337
+ # Calculate processing times and real-time factors
338
+ t = (dt.datetime.now() - start_t).total_seconds()
339
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
340
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
341
+ wav_seconds = wav.shape[-1] / sampling_rate
342
+ rtf = t / wav_seconds
343
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
344
+ rtf_vocoder = t_vocoder / wav_seconds
345
+ metrics = {
346
+ "t": t,
347
+ "t_no_vocoder": t_no_vocoder,
348
+ "t_vocoder": t_vocoder,
349
+ "wav_seconds": wav_seconds,
350
+ "rtf": rtf,
351
+ "rtf_no_vocoder": rtf_no_vocoder,
352
+ "rtf_vocoder": rtf_vocoder,
353
+ }
354
+
355
+ # Adjust wav volume if necessary
356
+ if prompt_rms < target_rms:
357
+ wav = wav * prompt_rms / target_rms
358
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
359
+
360
+ return metrics
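The volume handling in generate_sentence is symmetric: a quiet prompt is scaled up to target_rms before feature extraction, and the generated wav is scaled back down by the same ratio at the end. A minimal standalone sketch of that logic (function and variable names here are illustrative, not part of the script):

    import torch

    def rms_normalize(wav: torch.Tensor, target_rms: float = 0.1):
        # Scale quiet audio up to target_rms; keep the original RMS for later.
        rms = torch.sqrt(torch.mean(torch.square(wav)))
        if rms < target_rms:
            wav = wav * target_rms / rms
        return wav, rms

    def rms_restore(generated: torch.Tensor, prompt_rms: torch.Tensor, target_rms: float = 0.1):
        # Undo the scaling on the generated audio if the prompt was boosted.
        if prompt_rms < target_rms:
            generated = generated * prompt_rms / target_rms
        return generated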
361
+
362
+
363
+ def generate_sentence_stereo(
364
+ save_path: str,
365
+ prompt_text: str,
366
+ prompt_wav: Union[str, List[str]],
367
+ text: str,
368
+ model: torch.nn.Module,
369
+ vocoder: torch.nn.Module,
370
+ tokenizer: DialogTokenizer,
371
+ feature_extractor: VocosFbank,
372
+ device: torch.device,
373
+ num_step: int = 16,
374
+ guidance_scale: float = 1.0,
375
+ speed: float = 1.0,
376
+ t_shift: float = 0.5,
377
+ target_rms: float = 0.1,
378
+ feat_scale: float = 0.1,
379
+ sampling_rate: int = 24000,
380
+ silence_wav: Optional[str] = None,
381
+ ):
382
+ """
383
+ Generate waveform of a text based on a given prompt
384
+ waveform and its transcription.
385
+
386
+ Args:
387
+ save_path (str): Path to save the generated wav.
388
+ prompt_text (str): Transcription of the prompt wav.
389
+ prompt_wav (Union[str, List[str]]): Path to the prompt wav file, can be
390
+ one or two wav files, which correspond to a merged conversational
391
+ speech or two separate speakers' speech.
392
+ text (str): Text to be synthesized into a waveform.
393
+ model (torch.nn.Module): The model used for generation.
394
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
395
+ tokenizer (DialogTokenizer): The tokenizer used to convert text to tokens.
396
+ feature_extractor (VocosFbank): The feature extractor used to
397
+ extract acoustic features.
398
+ device (torch.device): The device on which computations are performed.
399
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
400
+ guidance_scale (float, optional): Scale for classifier-free guidance.
401
+ Defaults to 1.0.
402
+ speed (float, optional): Speed control. Defaults to 1.0.
403
+ t_shift (float, optional): Time shift. Defaults to 0.5.
404
+ target_rms (float, optional): Target RMS for waveform normalization.
405
+ Defaults to 0.1.
406
+ feat_scale (float, optional): Scale for features.
407
+ Defaults to 0.1.
408
+ sampling_rate (int, optional): Sampling rate for the waveform.
409
+ Defaults to 24000.
410
+ silence_wav (str): Path of the silence wav file, used in two-channel
411
+ generation with single-channel prompts
412
+ Returns:
413
+ metrics (dict): Dictionary containing time and real-time
414
+ factor metrics for processing.
415
+ """
416
+ # Convert text to tokens
417
+ tokens = tokenizer.texts_to_token_ids([text])
418
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
419
+
420
+ # Load and preprocess prompt wav
421
+ if isinstance(prompt_wav, str):
422
+ prompt_wav = [
423
+ prompt_wav,
424
+ ]
425
+ else:
426
+ assert len(prompt_wav) == 2 and isinstance(prompt_wav[0], str)
427
+
428
+ loaded_prompt_wavs = prompt_wav
429
+ for i in range(len(prompt_wav)):
430
+ loaded_prompt_wavs[i], prompt_sampling_rate = torchaudio.load(prompt_wav[i])
431
+ if prompt_sampling_rate != sampling_rate:
432
+ resampler = torchaudio.transforms.Resample(
433
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
434
+ )
435
+ loaded_prompt_wavs[i] = resampler(loaded_prompt_wavs[i])
436
+
437
+ if len(loaded_prompt_wavs) == 1:
438
+ assert (
439
+ loaded_prompt_wavs[0].size(0) == 2
440
+ ), "Merged prompt wav must be stereo for stereo dialogue generation"
441
+ prompt_wav = loaded_prompt_wavs[0]
442
+
443
+ else:
444
+ assert len(loaded_prompt_wavs) == 2
445
+ if loaded_prompt_wavs[0].size(0) == 2:
446
+ prompt_wav = torch.cat(loaded_prompt_wavs, dim=1)
447
+ else:
448
+ assert loaded_prompt_wavs[0].size(0) == 1
449
+ silence_wav, silence_sampling_rate = torchaudio.load(silence_wav)
450
+ assert silence_sampling_rate == sampling_rate
451
+ prompt_wav = silence_wav[
452
+ :, : loaded_prompt_wavs[0].size(1) + loaded_prompt_wavs[1].size(1)
453
+ ]
454
+ prompt_wav[0, : loaded_prompt_wavs[0].size(1)] = loaded_prompt_wavs[0]
455
+ prompt_wav[1, loaded_prompt_wavs[0].size(1) :] = loaded_prompt_wavs[1]
456
+
457
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
458
+ if prompt_rms < target_rms:
459
+ prompt_wav = prompt_wav * target_rms / prompt_rms
460
+
461
+ # Extract features from prompt wav
462
+ prompt_features = feature_extractor.extract(
463
+ prompt_wav, sampling_rate=sampling_rate
464
+ ).to(device)
465
+
466
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
467
+ prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)
468
+
469
+ # Start timing
470
+ start_t = dt.datetime.now()
471
+
472
+ # Generate features
473
+ (
474
+ pred_features,
475
+ pred_features_lens,
476
+ pred_prompt_features,
477
+ pred_prompt_features_lens,
478
+ ) = model.sample(
479
+ tokens=tokens,
480
+ prompt_tokens=prompt_tokens,
481
+ prompt_features=prompt_features,
482
+ prompt_features_lens=prompt_features_lens,
483
+ speed=speed,
484
+ t_shift=t_shift,
485
+ duration="predict",
486
+ num_step=num_step,
487
+ guidance_scale=guidance_scale,
488
+ )
489
+
490
+ # Postprocess predicted features
491
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
492
+
493
+ # Start vocoder processing
494
+ start_vocoder_t = dt.datetime.now()
495
+ feat_dim = pred_features.size(1) // 2
496
+ wav_left = vocoder.decode(pred_features[:, :feat_dim]).squeeze(1).clamp(-1, 1)
497
+ wav_right = (
498
+ vocoder.decode(pred_features[:, feat_dim : feat_dim * 2])
499
+ .squeeze(1)
500
+ .clamp(-1, 1)
501
+ )
502
+
503
+ wav = torch.cat([wav_left, wav_right], dim=0)
504
+
505
+ # Calculate processing times and real-time factors
506
+ t = (dt.datetime.now() - start_t).total_seconds()
507
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
508
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
509
+ wav_seconds = wav.shape[-1] / sampling_rate
510
+ rtf = t / wav_seconds
511
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
512
+ rtf_vocoder = t_vocoder / wav_seconds
513
+ metrics = {
514
+ "t": t,
515
+ "t_no_vocoder": t_no_vocoder,
516
+ "t_vocoder": t_vocoder,
517
+ "wav_seconds": wav_seconds,
518
+ "rtf": rtf,
519
+ "rtf_no_vocoder": rtf_no_vocoder,
520
+ "rtf_vocoder": rtf_vocoder,
521
+ }
522
+
523
+ # Adjust wav volume if necessary
524
+ if prompt_rms < target_rms:
525
+ wav = wav * prompt_rms / target_rms
526
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
527
+
528
+ return metrics
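When two mono prompts are supplied for stereo generation, the code above lays them out on separate channels: speaker 1 occupies the start of the left channel and speaker 2 follows on the right channel, with silence everywhere else. A minimal sketch of that layout using zeros in place of the silence wav (tensor names and lengths are illustrative):

    import torch

    spk1 = torch.randn(1, 24000)  # mono prompt of speaker 1, shape (1, T1)
    spk2 = torch.randn(1, 36000)  # mono prompt of speaker 2, shape (1, T2)

    stereo = torch.zeros(2, spk1.size(1) + spk2.size(1))
    stereo[0, : spk1.size(1)] = spk1[0]  # left channel: speaker 1, then silence
    stereo[1, spk1.size(1) :] = spk2[0]  # right channel: silence, then speaker 2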
529
+
530
+
531
+ def generate_list(
532
+ model_name: str,
533
+ res_dir: str,
534
+ test_list: str,
535
+ model: torch.nn.Module,
536
+ vocoder: torch.nn.Module,
537
+ tokenizer: DialogTokenizer,
538
+ feature_extractor: VocosFbank,
539
+ device: torch.device,
540
+ num_step: int = 16,
541
+ guidance_scale: float = 1.5,
542
+ speed: float = 1.0,
543
+ t_shift: float = 0.5,
544
+ target_rms: float = 0.1,
545
+ feat_scale: float = 0.1,
546
+ sampling_rate: int = 24000,
547
+ silence_wav: Optional[str] = None,
548
+ ):
549
+ total_t = []
550
+ total_t_no_vocoder = []
551
+ total_t_vocoder = []
552
+ total_wav_seconds = []
553
+
554
+ with open(test_list, "r") as fr:
555
+ lines = fr.readlines()
556
+
557
+ for i, line in enumerate(lines):
558
+ items = line.strip().split("\t")
559
+ if len(items) == 6:
560
+ (
561
+ wav_name,
562
+ prompt_text_1,
563
+ prompt_text_2,
564
+ prompt_wav_1,
565
+ prompt_wav_2,
566
+ text,
567
+ ) = items
568
+ prompt_text = f"[S1]{prompt_text_1}[S2]{prompt_text_2}"
569
+ prompt_wav = [prompt_wav_1, prompt_wav_2]
570
+ elif len(items) == 4:
571
+ wav_name, prompt_text, prompt_wav, text = items
572
+ else:
573
+ raise ValueError(f"Invalid line: {line}")
574
+ assert text.startswith("[S1]")
575
+
576
+ save_path = f"{res_dir}/{wav_name}.wav"
577
+
578
+ if model_name == "zipvoice_dialog":
579
+
580
+ metrics = generate_sentence(
581
+ save_path=save_path,
582
+ prompt_text=prompt_text,
583
+ prompt_wav=prompt_wav,
584
+ text=text,
585
+ model=model,
586
+ vocoder=vocoder,
587
+ tokenizer=tokenizer,
588
+ feature_extractor=feature_extractor,
589
+ device=device,
590
+ num_step=num_step,
591
+ guidance_scale=guidance_scale,
592
+ speed=speed,
593
+ t_shift=t_shift,
594
+ target_rms=target_rms,
595
+ feat_scale=feat_scale,
596
+ sampling_rate=sampling_rate,
597
+ )
598
+ else:
599
+ assert model_name == "zipvoice_dialog_stereo"
600
+ metrics = generate_sentence_stereo(
601
+ save_path=save_path,
602
+ prompt_text=prompt_text,
603
+ prompt_wav=prompt_wav,
604
+ text=text,
605
+ model=model,
606
+ vocoder=vocoder,
607
+ tokenizer=tokenizer,
608
+ feature_extractor=feature_extractor,
609
+ device=device,
610
+ num_step=num_step,
611
+ guidance_scale=guidance_scale,
612
+ speed=speed,
613
+ t_shift=t_shift,
614
+ target_rms=target_rms,
615
+ feat_scale=feat_scale,
616
+ sampling_rate=sampling_rate,
617
+ silence_wav=silence_wav,
618
+ )
619
+
620
+ print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
621
+ total_t.append(metrics["t"])
622
+ total_t_no_vocoder.append(metrics["t_no_vocoder"])
623
+ total_t_vocoder.append(metrics["t_vocoder"])
624
+ total_wav_seconds.append(metrics["wav_seconds"])
625
+
626
+ print(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
627
+ print(
628
+ f"Average RTF w/o vocoder: "
629
+ f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
630
+ )
631
+ print(
632
+ f"Average RTF vocoder: "
633
+ f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
634
+ )
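Note that the averages printed here are duration-weighted: total synthesis time divided by total generated audio, not the mean of per-sentence RTFs. For example, 12 seconds of compute for 60 seconds of generated audio gives an average RTF of 12 / 60 = 0.2, i.e. five times faster than real time.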
635
+
636
+
637
+ @torch.inference_mode()
638
+ def main():
639
+ parser = get_parser()
640
+ args = parser.parse_args()
641
+
642
+ params = AttributeDict()
643
+ params.update(vars(args))
644
+ fix_random_seed(params.seed)
645
+
646
+ assert (
647
+ params.test_list is not None
648
+ ), "For inference, please provide prompts and text with '--test-list'"
649
+
650
+ if torch.cuda.is_available():
651
+ params.device = torch.device("cuda", 0)
652
+ elif torch.backends.mps.is_available():
653
+ params.device = torch.device("mps")
654
+ else:
655
+ params.device = torch.device("cpu")
656
+
657
+ print("Loading model...")
658
+ if params.model_config is None:
659
+ model_config = hf_hub_download(
660
+ HUGGINGFACE_REPO, filename=MODEL_CONFIG[params.model_name]
661
+ )
662
+ else:
663
+ model_config = params.model_config
664
+
665
+ with open(model_config, "r") as f:
666
+ model_config = json.load(f)
667
+
668
+ if params.token_file is None:
669
+ token_file = hf_hub_download(
670
+ HUGGINGFACE_REPO, filename=TOKEN_FILE[params.model_name]
671
+ )
672
+ else:
673
+ token_file = params.token_file
674
+
675
+ tokenizer = DialogTokenizer(token_file=token_file)
676
+
677
+ tokenizer_config = {
678
+ "vocab_size": tokenizer.vocab_size,
679
+ "pad_id": tokenizer.pad_id,
680
+ "spk_a_id": tokenizer.spk_a_id,
681
+ "spk_b_id": tokenizer.spk_b_id,
682
+ }
683
+ if params.checkpoint is None:
684
+ model_ckpt = hf_hub_download(
685
+ HUGGINGFACE_REPO,
686
+ filename=PRETRAINED_MODEL[params.model_name],
687
+ )
688
+ else:
689
+ model_ckpt = params.checkpoint
690
+
691
+ if params.model_name == "zipvoice_dialog":
692
+ model = ZipVoiceDialog(
693
+ **model_config["model"],
694
+ **tokenizer_config,
695
+ )
696
+ else:
697
+ assert params.model_name == "zipvoice_dialog_stereo"
698
+ model = ZipVoiceDialogStereo(
699
+ **model_config["model"],
700
+ **tokenizer_config,
701
+ )
702
+
703
+ if model_ckpt.endswith(".safetensors"):
704
+ safetensors.torch.load_model(model, model_ckpt)
705
+ elif model_ckpt.endswith(".pt"):
706
+ load_checkpoint(filename=model_ckpt, model=model, strict=True)
707
+ else:
708
+ raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
709
+
710
+ model = model.to(params.device)
711
+ model.eval()
712
+
713
+ vocoder = get_vocoder(params.vocoder_path)
714
+ vocoder = vocoder.to(params.device)
715
+ vocoder.eval()
716
+
717
+ if model_config["feature"]["type"] == "vocos":
718
+ if params.model_name == "zipvoice_dialog":
719
+ num_channels = 1
720
+ else:
721
+ assert params.model_name == "zipvoice_dialog_stereo"
722
+ num_channels = 2
723
+ feature_extractor = VocosFbank(num_channels=num_channels)
724
+ else:
725
+ raise NotImplementedError(
726
+ f"Unsupported feature type: {model_config['feature']['type']}"
727
+ )
728
+ params.sampling_rate = model_config["feature"]["sampling_rate"]
729
+
730
+ print("Start generating...")
731
+ os.makedirs(params.res_dir, exist_ok=True)
732
+ generate_list(
733
+ model_name=params.model_name,
734
+ res_dir=params.res_dir,
735
+ test_list=params.test_list,
736
+ model=model,
737
+ vocoder=vocoder,
738
+ tokenizer=tokenizer,
739
+ feature_extractor=feature_extractor,
740
+ device=params.device,
741
+ num_step=params.num_step,
742
+ guidance_scale=params.guidance_scale,
743
+ speed=params.speed,
744
+ t_shift=params.t_shift,
745
+ target_rms=params.target_rms,
746
+ feat_scale=params.feat_scale,
747
+ sampling_rate=params.sampling_rate,
748
+ silence_wav=params.silence_wav,
749
+ )
750
+ print("Done")
751
+
752
+
753
+ if __name__ == "__main__":
754
+ torch.set_num_threads(1)
755
+ torch.set_num_interop_threads(1)
756
+ main()
zipvoice/bin/infer_zipvoice_onnx.py ADDED
@@ -0,0 +1,715 @@
1
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu,
2
+ # Zengwei Yao)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """
18
+ This script generates speech with our pre-trained ZipVoice or ZipVoice-Distill
19
+ ONNX models. If no local model is specified,
20
+ the required files will be automatically downloaded from HuggingFace.
21
+
22
+ Usage:
23
+
24
+ Note: If you have trouble connecting to HuggingFace,
25
+ try switching endpoint to mirror site:
26
+ export HF_ENDPOINT=https://hf-mirror.com
27
+
28
+ (1) Inference of a single sentence:
29
+
30
+ python3 -m zipvoice.bin.infer_zipvoice_onnx \
31
+ --onnx-int8 False \
32
+ --model-name "zipvoice" \
33
+ --prompt-wav prompt.wav \
34
+ --prompt-text "I am a prompt." \
35
+ --text "I am a sentence." \
36
+ --res-wav-path result.wav
37
+
38
+ (2) Inference of a list of sentences:
39
+ python3 -m zipvoice.bin.infer_zipvoice_onnx \
40
+ --onnx-int8 False \
41
+ --model-name "zipvoice" \
42
+ --test-list test.tsv \
43
+ --res-dir results
44
+
45
+ `--model-name` can be `zipvoice` or `zipvoice_distill`,
46
+ which are the models before and after distillation, respectively.
47
+
48
+ Each line of `test.tsv` is in the format of
49
+ `{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`.
50
+
51
+ Set `--onnx-int8 True` to use the int8-quantized ONNX model.
52
+ """
53
+
54
+ import argparse
55
+ import datetime as dt
56
+ import json
57
+ import os
58
+ from typing import List, Tuple
59
+
60
+ import numpy as np
61
+ import onnxruntime as ort
62
+ import torch
63
+ import torchaudio
64
+ from huggingface_hub import hf_hub_download
65
+ from lhotse.utils import fix_random_seed
66
+ from torch import Tensor, nn
67
+
68
+ from zipvoice.bin.infer_zipvoice import get_vocoder
69
+ from zipvoice.models.modules.solver import get_time_steps
70
+ from zipvoice.tokenizer.tokenizer import (
71
+ EmiliaTokenizer,
72
+ EspeakTokenizer,
73
+ LibriTTSTokenizer,
74
+ SimpleTokenizer,
75
+ )
76
+ from zipvoice.utils.common import AttributeDict, str2bool
77
+ from zipvoice.utils.feature import VocosFbank
78
+
79
+ HUGGINGFACE_REPO = "k2-fsa/ZipVoice"
80
+ TOKEN_FILE = {
81
+ "zipvoice": "zipvoice/tokens.txt",
82
+ "zipvoice_distill": "zipvoice_distill/tokens.txt",
83
+ }
84
+ MODEL_CONFIG = {
85
+ "zipvoice": "zipvoice/model.json",
86
+ "zipvoice_distill": "zipvoice_distill/model.json",
87
+ }
88
+
89
+
90
+ def get_parser():
91
+ parser = argparse.ArgumentParser(
92
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
93
+ )
94
+
95
+ parser.add_argument(
96
+ "--onnx-int8",
97
+ type=str2bool,
98
+ default=False,
99
+ help="Whether to use the int8 model",
100
+ )
101
+
102
+ parser.add_argument(
103
+ "--model-name",
104
+ type=str,
105
+ default="zipvoice",
106
+ choices=["zipvoice", "zipvoice_distill"],
107
+ help="The model used for inference",
108
+ )
109
+
110
+ parser.add_argument(
111
+ "--onnx-model-dir",
112
+ type=str,
113
+ default=None,
114
+ help="The path to the local onnx model. "
115
+ "Will download pre-trained checkpoint from huggingface if not specified.",
116
+ )
117
+
118
+ parser.add_argument(
119
+ "--model-config",
120
+ type=str,
121
+ default=None,
122
+ help="The model configuration file. "
123
+ "Will download model.json from huggingface if not specified.",
124
+ )
125
+
126
+ parser.add_argument(
127
+ "--vocoder-path",
128
+ type=str,
129
+ default=None,
130
+ help="The vocoder checkpoint. "
131
+ "Will download pre-trained vocoder from huggingface if not specified.",
132
+ )
133
+
134
+ parser.add_argument(
135
+ "--token-file",
136
+ type=str,
137
+ default=None,
138
+ help="The file that contains information that maps tokens to ids,"
139
+ "which is a text file with '{token}\t{token_id}' per line. "
140
+ "Will download tokens_emilia.txt from huggingface if not specified.",
141
+ )
142
+
143
+ parser.add_argument(
144
+ "--tokenizer",
145
+ type=str,
146
+ default="emilia",
147
+ choices=["emilia", "libritts", "espeak", "simple"],
148
+ help="Tokenizer type.",
149
+ )
150
+
151
+ parser.add_argument(
152
+ "--lang",
153
+ type=str,
154
+ default="en-us",
155
+ help="Language identifier, used when tokenizer type is espeak. see"
156
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
157
+ )
158
+
159
+ parser.add_argument(
160
+ "--test-list",
161
+ type=str,
162
+ default=None,
163
+ help="The list of prompt speech, prompt_transcription, "
164
+ "and text to synthesizein the format of "
165
+ "'{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}'.",
166
+ )
167
+
168
+ parser.add_argument(
169
+ "--prompt-wav",
170
+ type=str,
171
+ default=None,
172
+ help="The prompt wav to mimic",
173
+ )
174
+
175
+ parser.add_argument(
176
+ "--prompt-text",
177
+ type=str,
178
+ default=None,
179
+ help="The transcription of the prompt wav",
180
+ )
181
+
182
+ parser.add_argument(
183
+ "--text",
184
+ type=str,
185
+ default=None,
186
+ help="The text to synthesize",
187
+ )
188
+
189
+ parser.add_argument(
190
+ "--res-dir",
191
+ type=str,
192
+ default="results",
193
+ help="""
194
+ Path name of the generated wavs dir,
195
+ used when test-list is not None
196
+ """,
197
+ )
198
+
199
+ parser.add_argument(
200
+ "--res-wav-path",
201
+ type=str,
202
+ default="result.wav",
203
+ help="""
204
+ Path name of the generated wav path,
205
+ used when test-list is None
206
+ """,
207
+ )
208
+
209
+ parser.add_argument(
210
+ "--guidance-scale",
211
+ type=float,
212
+ default=None,
213
+ help="The scale of classifier-free guidance during inference.",
214
+ )
215
+
216
+ parser.add_argument(
217
+ "--num-step",
218
+ type=int,
219
+ default=None,
220
+ help="The number of sampling steps.",
221
+ )
222
+
223
+ parser.add_argument(
224
+ "--feat-scale",
225
+ type=float,
226
+ default=0.1,
227
+ help="The scale factor of fbank feature",
228
+ )
229
+
230
+ parser.add_argument(
231
+ "--speed",
232
+ type=float,
233
+ default=1.0,
234
+ help="Control speech speed, 1.0 means normal, >1.0 means speed up",
235
+ )
236
+
237
+ parser.add_argument(
238
+ "--t-shift",
239
+ type=float,
240
+ default=0.5,
241
+ help="Shift t to smaller ones if t_shift < 1.0",
242
+ )
243
+
244
+ parser.add_argument(
245
+ "--target-rms",
246
+ type=float,
247
+ default=0.1,
248
+ help="Target speech normalization rms value, set to 0 to disable normalization",
249
+ )
250
+
251
+ parser.add_argument(
252
+ "--seed",
253
+ type=int,
254
+ default=666,
255
+ help="Random seed",
256
+ )
257
+
258
+ return parser
259
+
260
+
261
+ class OnnxModel:
262
+ def __init__(
263
+ self,
264
+ text_encoder_path: str,
265
+ fm_decoder_path: str,
266
+ ):
267
+ session_opts = ort.SessionOptions()
268
+ session_opts.inter_op_num_threads = 1
269
+ session_opts.intra_op_num_threads = 1
270
+
271
+ self.session_opts = session_opts
272
+
273
+ self.init_text_encoder(text_encoder_path)
274
+ self.init_fm_decoder(fm_decoder_path)
275
+
276
+ def init_text_encoder(self, model_path: str):
277
+ self.text_encoder = ort.InferenceSession(
278
+ model_path,
279
+ sess_options=self.session_opts,
280
+ providers=["CPUExecutionProvider"],
281
+ )
282
+
283
+ def init_fm_decoder(self, model_path: str):
284
+ self.fm_decoder = ort.InferenceSession(
285
+ model_path,
286
+ sess_options=self.session_opts,
287
+ providers=["CPUExecutionProvider"],
288
+ )
289
+ meta = self.fm_decoder.get_modelmeta().custom_metadata_map
290
+ self.feat_dim = int(meta["feat_dim"])
291
+
292
+ def run_text_encoder(
293
+ self,
294
+ tokens: Tensor,
295
+ prompt_tokens: Tensor,
296
+ prompt_features_len: Tensor,
297
+ speed: Tensor,
298
+ ) -> Tuple[Tensor, Tensor]:
299
+ out = self.text_encoder.run(
300
+ [
301
+ self.text_encoder.get_outputs()[0].name,
302
+ ],
303
+ {
304
+ self.text_encoder.get_inputs()[0].name: tokens.numpy(),
305
+ self.text_encoder.get_inputs()[1].name: prompt_tokens.numpy(),
306
+ self.text_encoder.get_inputs()[2].name: prompt_features_len.numpy(),
307
+ self.text_encoder.get_inputs()[3].name: speed.numpy(),
308
+ },
309
+ )
310
+ return torch.from_numpy(out[0])
311
+
312
+ def run_fm_decoder(
313
+ self,
314
+ t: Tensor,
315
+ x: Tensor,
316
+ text_condition: Tensor,
317
+ speech_condition: torch.Tensor,
318
+ guidance_scale: Tensor,
319
+ ) -> Tensor:
320
+ out = self.fm_decoder.run(
321
+ [
322
+ self.fm_decoder.get_outputs()[0].name,
323
+ ],
324
+ {
325
+ self.fm_decoder.get_inputs()[0].name: t.numpy(),
326
+ self.fm_decoder.get_inputs()[1].name: x.numpy(),
327
+ self.fm_decoder.get_inputs()[2].name: text_condition.numpy(),
328
+ self.fm_decoder.get_inputs()[3].name: speech_condition.numpy(),
329
+ self.fm_decoder.get_inputs()[4].name: guidance_scale.numpy(),
330
+ },
331
+ )
332
+ return torch.from_numpy(out[0])
333
+
334
+
335
+ def sample(
336
+ model: OnnxModel,
337
+ tokens: List[List[int]],
338
+ prompt_tokens: List[List[int]],
339
+ prompt_features: Tensor,
340
+ speed: float = 1.0,
341
+ t_shift: float = 0.5,
342
+ guidance_scale: float = 1.0,
343
+ num_step: int = 16,
344
+ ) -> torch.Tensor:
345
+ """
346
+ Generate acoustic features, given text tokens, prompt features and the prompt
347
+ transcription's text tokens.
348
+
349
+ Args:
350
+ tokens: a list of list of text tokens.
351
+ prompt_tokens: a list of list of prompt tokens.
352
+ prompt_features: the prompt feature with the shape
353
+ (batch_size, seq_len, feat_dim).
354
+ speed: speed control.
355
+ t_shift: time shift.
356
+ guidance_scale: the guidance scale for classifier-free guidance.
357
+ num_step: the number of steps to use in the ODE solver.
358
+ """
359
+ # Run text encoder
360
+ assert len(tokens) == len(prompt_tokens) == 1
361
+ tokens = torch.tensor(tokens, dtype=torch.int64)
362
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.int64)
363
+ prompt_features_len = torch.tensor(prompt_features.size(1), dtype=torch.int64)
364
+ speed = torch.tensor(speed, dtype=torch.float32)
365
+
366
+ text_condition = model.run_text_encoder(
367
+ tokens, prompt_tokens, prompt_features_len, speed
368
+ )
369
+
370
+ batch_size, num_frames, _ = text_condition.shape
371
+ assert batch_size == 1
372
+ feat_dim = model.feat_dim
373
+
374
+ # Run flow matching model
375
+ timesteps = get_time_steps(
376
+ t_start=0.0,
377
+ t_end=1.0,
378
+ num_step=num_step,
379
+ t_shift=t_shift,
380
+ )
381
+ x = torch.randn(batch_size, num_frames, feat_dim)
382
+ speech_condition = torch.nn.functional.pad(
383
+ prompt_features, (0, 0, 0, num_frames - prompt_features.shape[1])
384
+ ) # (B, T, F)
385
+ guidance_scale = torch.tensor(guidance_scale, dtype=torch.float32)
386
+
387
+ for step in range(num_step):
388
+ v = model.run_fm_decoder(
389
+ t=timesteps[step],
390
+ x=x,
391
+ text_condition=text_condition,
392
+ speech_condition=speech_condition,
393
+ guidance_scale=guidance_scale,
394
+ )
395
+ x = x + v * (timesteps[step + 1] - timesteps[step])
396
+
397
+ x = x[:, prompt_features_len.item() :, :]
398
+ return x
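The loop above is a plain Euler solver for the flow-matching ODE: at each step the decoder predicts a velocity field and the state moves along it for the length of the step, x_{k+1} = x_k + v(x_k, t_k) * (t_{k+1} - t_k). A minimal sketch of the same update with a stand-in velocity function (v_fn is illustrative, not part of the script):

    import torch

    def euler_solve(x0: torch.Tensor, timesteps: torch.Tensor, v_fn) -> torch.Tensor:
        # Integrate dx/dt = v_fn(x, t) from timesteps[0] to timesteps[-1] with Euler steps.
        x = x0
        for k in range(len(timesteps) - 1):
            v = v_fn(x, timesteps[k])
            x = x + v * (timesteps[k + 1] - timesteps[k])
        return x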
399
+
400
+
401
+ # Copied from zipvoice/infer/infer_zipvoice.py, but calls an external sample function
402
+ def generate_sentence(
403
+ save_path: str,
404
+ prompt_text: str,
405
+ prompt_wav: str,
406
+ text: str,
407
+ model: OnnxModel,
408
+ vocoder: nn.Module,
409
+ tokenizer: EmiliaTokenizer,
410
+ feature_extractor: VocosFbank,
411
+ num_step: int = 16,
412
+ guidance_scale: float = 1.0,
413
+ speed: float = 1.0,
414
+ t_shift: float = 0.5,
415
+ target_rms: float = 0.1,
416
+ feat_scale: float = 0.1,
417
+ sampling_rate: int = 24000,
418
+ ):
419
+ """
420
+ Generate waveform of a text based on a given prompt
421
+ waveform and its transcription.
422
+
423
+ Args:
424
+ save_path (str): Path to save the generated wav.
425
+ prompt_text (str): Transcription of the prompt wav.
426
+ prompt_wav (str): Path to the prompt wav file.
427
+ text (str): Text to be synthesized into a waveform.
428
+ model (torch.nn.Module): The model used for generation.
429
+ vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
430
+ tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
431
+ feature_extractor (VocosFbank): The feature extractor used to
432
+ extract acoustic features.
433
+ num_step (int, optional): Number of steps for decoding. Defaults to 16.
434
+ guidance_scale (float, optional): Scale for classifier-free guidance.
435
+ Defaults to 1.0.
436
+ speed (float, optional): Speed control. Defaults to 1.0.
437
+ t_shift (float, optional): Time shift. Defaults to 0.5.
438
+ target_rms (float, optional): Target RMS for waveform normalization.
439
+ Defaults to 0.1.
440
+ feat_scale (float, optional): Scale for features.
441
+ Defaults to 0.1.
442
+ sampling_rate (int, optional): Sampling rate for the waveform.
443
+ Defaults to 24000.
444
+ Returns:
445
+ metrics (dict): Dictionary containing time and real-time
446
+ factor metrics for processing.
447
+ """
448
+ # Convert text to tokens
449
+ tokens = tokenizer.texts_to_token_ids([text])
450
+ prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])
451
+
452
+ # Load and preprocess prompt wav
453
+ prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
454
+
455
+ if prompt_sampling_rate != sampling_rate:
456
+ resampler = torchaudio.transforms.Resample(
457
+ orig_freq=prompt_sampling_rate, new_freq=sampling_rate
458
+ )
459
+ prompt_wav = resampler(prompt_wav)
460
+
461
+ prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
462
+ if prompt_rms < target_rms:
463
+ prompt_wav = prompt_wav * target_rms / prompt_rms
464
+
465
+ # Extract features from prompt wav
466
+ prompt_features = feature_extractor.extract(prompt_wav, sampling_rate=sampling_rate)
467
+
468
+ prompt_features = prompt_features.unsqueeze(0) * feat_scale
469
+
470
+ # Start timing
471
+ start_t = dt.datetime.now()
472
+
473
+ # Generate features
474
+ pred_features = sample(
475
+ model=model,
476
+ tokens=tokens,
477
+ prompt_tokens=prompt_tokens,
478
+ prompt_features=prompt_features,
479
+ speed=speed,
480
+ t_shift=t_shift,
481
+ guidance_scale=guidance_scale,
482
+ num_step=num_step,
483
+ )
484
+
485
+ # Postprocess predicted features
486
+ pred_features = pred_features.permute(0, 2, 1) / feat_scale # (B, C, T)
487
+
488
+ # Start vocoder processing
489
+ start_vocoder_t = dt.datetime.now()
490
+ wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)
491
+
492
+ # Calculate processing times and real-time factors
493
+ t = (dt.datetime.now() - start_t).total_seconds()
494
+ t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
495
+ t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
496
+ wav_seconds = wav.shape[-1] / sampling_rate
497
+ rtf = t / wav_seconds
498
+ rtf_no_vocoder = t_no_vocoder / wav_seconds
499
+ rtf_vocoder = t_vocoder / wav_seconds
500
+ metrics = {
501
+ "t": t,
502
+ "t_no_vocoder": t_no_vocoder,
503
+ "t_vocoder": t_vocoder,
504
+ "wav_seconds": wav_seconds,
505
+ "rtf": rtf,
506
+ "rtf_no_vocoder": rtf_no_vocoder,
507
+ "rtf_vocoder": rtf_vocoder,
508
+ }
509
+
510
+ # Adjust wav volume if necessary
511
+ if prompt_rms < target_rms:
512
+ wav = wav * prompt_rms / target_rms
513
+ torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
514
+
515
+ return metrics
516
+
517
+
518
+ def generate_list(
519
+ res_dir: str,
520
+ test_list: str,
521
+ model: OnnxModel,
522
+ vocoder: nn.Module,
523
+ tokenizer: EmiliaTokenizer,
524
+ feature_extractor: VocosFbank,
525
+ num_step: int = 16,
526
+ guidance_scale: float = 1.0,
527
+ speed: float = 1.0,
528
+ t_shift: float = 0.5,
529
+ target_rms: float = 0.1,
530
+ feat_scale: float = 0.1,
531
+ sampling_rate: int = 24000,
532
+ ):
533
+ total_t = []
534
+ total_t_no_vocoder = []
535
+ total_t_vocoder = []
536
+ total_wav_seconds = []
537
+
538
+ with open(test_list, "r") as fr:
539
+ lines = fr.readlines()
540
+
541
+ for i, line in enumerate(lines):
542
+ wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
543
+ save_path = f"{res_dir}/{wav_name}.wav"
544
+ metrics = generate_sentence(
545
+ save_path=save_path,
546
+ prompt_text=prompt_text,
547
+ prompt_wav=prompt_wav,
548
+ text=text,
549
+ model=model,
550
+ vocoder=vocoder,
551
+ tokenizer=tokenizer,
552
+ feature_extractor=feature_extractor,
553
+ num_step=num_step,
554
+ guidance_scale=guidance_scale,
555
+ speed=speed,
556
+ t_shift=t_shift,
557
+ target_rms=target_rms,
558
+ feat_scale=feat_scale,
559
+ sampling_rate=sampling_rate,
560
+ )
561
+ print(f"[Sentence: {i}] RTF: {metrics['rtf']:.4f}")
562
+ total_t.append(metrics["t"])
563
+ total_t_no_vocoder.append(metrics["t_no_vocoder"])
564
+ total_t_vocoder.append(metrics["t_vocoder"])
565
+ total_wav_seconds.append(metrics["wav_seconds"])
566
+
567
+ print(f"Average RTF: {np.sum(total_t) / np.sum(total_wav_seconds):.4f}")
568
+ print(
569
+ f"Average RTF w/o vocoder: "
570
+ f"{np.sum(total_t_no_vocoder) / np.sum(total_wav_seconds):.4f}"
571
+ )
572
+ print(
573
+ f"Average RTF vocoder: "
574
+ f"{np.sum(total_t_vocoder) / np.sum(total_wav_seconds):.4f}"
575
+ )
576
+
577
+
578
+ @torch.inference_mode()
579
+ def main():
580
+ parser = get_parser()
581
+ args = parser.parse_args()
582
+
583
+ params = AttributeDict()
584
+ params.update(vars(args))
585
+ fix_random_seed(params.seed)
586
+
587
+ model_defaults = {
588
+ "zipvoice": {
589
+ "num_step": 16,
590
+ "guidance_scale": 1.0,
591
+ },
592
+ "zipvoice_distill": {
593
+ "num_step": 8,
594
+ "guidance_scale": 3.0,
595
+ },
596
+ }
597
+
598
+ model_specific_defaults = model_defaults.get(params.model_name, {})
599
+
600
+ for param, value in model_specific_defaults.items():
601
+ if getattr(params, param) is None:
602
+ setattr(params, param, value)
603
+ print(f"Setting {param} to default value: {value}")
604
+
605
+ assert (params.test_list is not None) ^ (
606
+ (params.prompt_wav and params.prompt_text and params.text) is not None
607
+ ), (
608
+ "For inference, please provide prompts and text with either '--test-list'"
609
+ " or '--prompt-wav, --prompt-text and --text'."
610
+ )
611
+
612
+ print("Loading model...")
613
+ if params.model_config is None:
614
+ model_config = hf_hub_download(
615
+ HUGGINGFACE_REPO, filename=MODEL_CONFIG[params.model_name]
616
+ )
617
+ else:
618
+ model_config = params.model_config
619
+
620
+ with open(model_config, "r") as f:
621
+ model_config = json.load(f)
622
+
623
+ if params.token_file is None:
624
+ token_file = hf_hub_download(
625
+ HUGGINGFACE_REPO, filename=TOKEN_FILE[params.model_name]
626
+ )
627
+ else:
628
+ token_file = params.token_file
629
+
630
+ if params.tokenizer == "emilia":
631
+ tokenizer = EmiliaTokenizer(token_file=token_file)
632
+ elif params.tokenizer == "libritts":
633
+ tokenizer = LibriTTSTokenizer(token_file=token_file)
634
+ elif params.tokenizer == "espeak":
635
+ tokenizer = EspeakTokenizer(token_file=token_file, lang=params.lang)
636
+ else:
637
+ assert params.tokenizer == "simple"
638
+ tokenizer = SimpleTokenizer(token_file=token_file)
639
+
640
+ if params.onnx_model_dir is not None:
641
+ dirname = params.onnx_model_dir
642
+ else:
643
+ if params.model_name == "zipvoice_distill":
644
+ dirname = "zipvoice_distill"
645
+ else:
646
+ dirname = "zipvoice"
647
+
648
+ if not params.onnx_int8:
649
+ text_encoder_path = f"{dirname}/text_encoder.onnx"
650
+ fm_decoder_path = f"{dirname}/fm_decoder.onnx"
651
+ else:
652
+ text_encoder_path = f"{dirname}/text_encoder_int8.onnx"
653
+ fm_decoder_path = f"{dirname}/fm_decoder_int8.onnx"
654
+ if params.onnx_model_dir is None:
655
+ text_encoder_path = hf_hub_download(
656
+ HUGGINGFACE_REPO, filename=text_encoder_path
657
+ )
658
+ fm_decoder_path = hf_hub_download(HUGGINGFACE_REPO, filename=fm_decoder_path)
659
+
660
+ model = OnnxModel(text_encoder_path, fm_decoder_path)
661
+
662
+ vocoder = get_vocoder(params.vocoder_path)
663
+ vocoder.eval()
664
+
665
+ if model_config["feature"]["type"] == "vocos":
666
+ feature_extractor = VocosFbank()
667
+ else:
668
+ raise NotImplementedError(
669
+ f"Unsupported feature type: {model_config['feature']['type']}"
670
+ )
671
+ params.sampling_rate = model_config["feature"]["sampling_rate"]
672
+
673
+ print("Start generating...")
674
+ if params.test_list:
675
+ os.makedirs(params.res_dir, exist_ok=True)
676
+ generate_list(
677
+ res_dir=params.res_dir,
678
+ test_list=params.test_list,
679
+ model=model,
680
+ vocoder=vocoder,
681
+ tokenizer=tokenizer,
682
+ feature_extractor=feature_extractor,
683
+ num_step=params.num_step,
684
+ guidance_scale=params.guidance_scale,
685
+ speed=params.speed,
686
+ t_shift=params.t_shift,
687
+ target_rms=params.target_rms,
688
+ feat_scale=params.feat_scale,
689
+ sampling_rate=params.sampling_rate,
690
+ )
691
+ else:
692
+ generate_sentence(
693
+ save_path=params.res_wav_path,
694
+ prompt_text=params.prompt_text,
695
+ prompt_wav=params.prompt_wav,
696
+ text=params.text,
697
+ model=model,
698
+ vocoder=vocoder,
699
+ tokenizer=tokenizer,
700
+ feature_extractor=feature_extractor,
701
+ num_step=params.num_step,
702
+ guidance_scale=params.guidance_scale,
703
+ speed=params.speed,
704
+ t_shift=params.t_shift,
705
+ target_rms=params.target_rms,
706
+ feat_scale=params.feat_scale,
707
+ sampling_rate=params.sampling_rate,
708
+ )
709
+ print("Done")
710
+
711
+
712
+ if __name__ == "__main__":
713
+ torch.set_num_threads(1)
714
+ torch.set_num_interop_threads(1)
715
+ main()
zipvoice/bin/onnx_export.py ADDED
@@ -0,0 +1,404 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Zengwei Yao)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This script exports a pre-trained ZipVoice or ZipVoice-Distill model from PyTorch to
20
+ ONNX.
21
+
22
+ Usage:
23
+
24
+ python3 -m zipvoice.bin.onnx_export \
25
+ --model-name zipvoice \
26
+ --token-file data/tokens_emilia.txt \
27
+ --checkpoint exp/zipvoice/epoch-11-avg-4.pt \
28
+ --model-config conf/zipvoice_base.json \
29
+ --onnx-model-dir exp/zipvoice_onnx
30
+
31
+ `--model-name` can be `zipvoice` or `zipvoice_distill`,
32
+ which are the models before and after distillation, respectively.
33
+ """
34
+
35
+
36
+ import argparse
37
+ import json
38
+ import os
39
+ from typing import Dict
40
+
41
+ import onnx
42
+ import safetensors.torch
43
+ import torch
44
+ from onnxruntime.quantization import QuantType, quantize_dynamic
45
+ from torch import Tensor, nn
46
+
47
+ from zipvoice.models.zipvoice import ZipVoice
48
+ from zipvoice.models.zipvoice_distill import ZipVoiceDistill
49
+ from zipvoice.tokenizer.tokenizer import SimpleTokenizer
50
+ from zipvoice.utils.checkpoint import load_checkpoint
51
+ from zipvoice.utils.common import AttributeDict
52
+ from zipvoice.utils.scaling_converter import convert_scaled_to_non_scaled
53
+
54
+
55
+ def get_parser():
56
+ parser = argparse.ArgumentParser(
57
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
58
+ )
59
+
60
+ parser.add_argument(
61
+ "--onnx-model-dir",
62
+ type=str,
63
+ default="exp",
64
+ help="Dir to the exported models",
65
+ )
66
+
67
+ parser.add_argument(
68
+ "--model-name",
69
+ type=str,
70
+ default="zipvoice",
71
+ choices=["zipvoice", "zipvoice_distill"],
72
+ help="The model used for inference",
73
+ )
74
+
75
+ parser.add_argument(
76
+ "--token-file",
77
+ type=str,
78
+ default="data/tokens_emilia.txt",
79
+ help="The file that contains information that maps tokens to ids,"
80
+ "which is a text file with '{token}\t{token_id}' per line.",
81
+ )
82
+
83
+ parser.add_argument(
84
+ "--checkpoint",
85
+ type=str,
86
+ default="exp_zipvoice/epoch-11-avg-4.pt",
87
+ help="The model checkpoint.",
88
+ )
89
+
90
+ parser.add_argument(
91
+ "--model-config",
92
+ type=str,
93
+ default="conf/zipvoice_base.json",
94
+ help="The model configuration file.",
95
+ )
96
+
97
+ return parser
98
+
99
+
100
+ def add_meta_data(filename: str, meta_data: Dict[str, str]):
101
+ """Add meta data to an ONNX model. It is changed in-place.
102
+
103
+ Args:
104
+ filename:
105
+ Filename of the ONNX model to be changed.
106
+ meta_data:
107
+ Key-value pairs.
108
+ """
109
+ model = onnx.load(filename)
110
+ for key, value in meta_data.items():
111
+ meta = model.metadata_props.add()
112
+ meta.key = key
113
+ meta.value = value
114
+
115
+ onnx.save(model, filename)
116
+
117
+
118
+ class OnnxTextModel(nn.Module):
119
+ def __init__(self, model: nn.Module):
120
+ """A wrapper for ZipVoice text encoder."""
121
+ super().__init__()
122
+ self.embed = model.embed
123
+ self.text_encoder = model.text_encoder
124
+ self.pad_id = model.pad_id
125
+
126
+ def forward(
127
+ self,
128
+ tokens: Tensor,
129
+ prompt_tokens: Tensor,
130
+ prompt_features_len: Tensor,
131
+ speed: Tensor,
132
+ ) -> Tensor:
133
+ cat_tokens = torch.cat([prompt_tokens, tokens], dim=1)
134
+ cat_tokens = nn.functional.pad(cat_tokens, (0, 1), value=self.pad_id)
135
+ tokens_len = cat_tokens.shape[1] - 1
136
+ padding_mask = (torch.arange(tokens_len + 1) == tokens_len).unsqueeze(0)
137
+
138
+ embed = self.embed(cat_tokens)
139
+ embed = self.text_encoder(x=embed, t=None, padding_mask=padding_mask)
140
+
141
+ features_len = torch.ceil(
142
+ (prompt_features_len / prompt_tokens.shape[1] * tokens_len / speed)
143
+ ).to(dtype=torch.int64)
144
+
145
+ token_dur = torch.div(features_len, tokens_len, rounding_mode="floor").to(
146
+ dtype=torch.int64
147
+ )
148
+
149
+ text_condition = embed[:, :-1, :].unsqueeze(2).expand(-1, -1, token_dur, -1)
150
+ text_condition = text_condition.reshape(embed.shape[0], -1, embed.shape[2])
151
+
152
+ text_condition = torch.cat(
153
+ [
154
+ text_condition,
155
+ embed[:, -1:, :].expand(-1, features_len - text_condition.shape[1], -1),
156
+ ],
157
+ dim=1,
158
+ )
159
+
160
+ return text_condition
161
+
162
+
163
+ class OnnxFlowMatchingModel(nn.Module):
164
+ def __init__(self, model: nn.Module):
165
+ """A wrapper for ZipVoice flow-matching decoder."""
166
+ super().__init__()
167
+ self.distill = model.distill
168
+ self.fm_decoder = model.fm_decoder
169
+ self.model_func = getattr(model, "forward_fm_decoder")
170
+ self.feat_dim = model.feat_dim
171
+
172
+ def forward(
173
+ self,
174
+ t: Tensor,
175
+ x: Tensor,
176
+ text_condition: Tensor,
177
+ speech_condition: torch.Tensor,
178
+ guidance_scale: Tensor,
179
+ ) -> Tensor:
180
+ if self.distill:
181
+ return self.model_func(
182
+ t=t,
183
+ xt=x,
184
+ text_condition=text_condition,
185
+ speech_condition=speech_condition,
186
+ guidance_scale=guidance_scale,
187
+ )
188
+ else:
189
+ x = x.repeat(2, 1, 1)
190
+ text_condition = torch.cat(
191
+ [torch.zeros_like(text_condition), text_condition], dim=0
192
+ )
193
+ speech_condition = torch.cat(
194
+ [
195
+ torch.where(
196
+ t > 0.5, torch.zeros_like(speech_condition), speech_condition
197
+ ),
198
+ speech_condition,
199
+ ],
200
+ dim=0,
201
+ )
202
+ guidance_scale = torch.where(t > 0.5, guidance_scale, guidance_scale * 2.0)
203
+ data_uncond, data_cond = self.model_func(
204
+ t=t,
205
+ xt=x,
206
+ text_condition=text_condition,
207
+ speech_condition=speech_condition,
208
+ ).chunk(2, dim=0)
209
+ v = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
210
+ return v
211
+
212
+
213
+ def export_text_encoder(
214
+ model: OnnxTextModel,
215
+ filename: str,
216
+ opset_version: int = 11,
217
+ ) -> None:
218
+ """Export the text encoder model to ONNX format.
219
+
220
+ Args:
221
+ model:
222
+ The input model
223
+ filename:
224
+ The filename to save the exported ONNX model.
225
+ opset_version:
226
+ The opset version to use.
227
+ """
228
+ tokens = torch.tensor([[2, 3, 4, 5]], dtype=torch.int64)
229
+ prompt_tokens = torch.tensor([[0, 1]], dtype=torch.int64)
230
+ prompt_features_len = torch.tensor(10, dtype=torch.int64)
231
+ speed = torch.tensor(1.0, dtype=torch.float32)
232
+
233
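+ # Trace the wrapper with example inputs first; torch.onnx.export below then exports the traced graph.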
+ model = torch.jit.trace(model, (tokens, prompt_tokens, prompt_features_len, speed))
234
+
235
+ torch.onnx.export(
236
+ model,
237
+ (tokens, prompt_tokens, prompt_features_len, speed),
238
+ filename,
239
+ verbose=False,
240
+ opset_version=opset_version,
241
+ input_names=["tokens", "prompt_tokens", "prompt_features_len", "speed"],
242
+ output_names=["text_condition"],
243
+ dynamic_axes={
244
+ "tokens": {0: "N", 1: "T"},
245
+ "prompt_tokens": {0: "N", 1: "T"},
246
+ "text_condition": {0: "N", 1: "T"},
247
+ },
248
+ )
249
+
250
+ meta_data = {
251
+ "version": "1",
252
+ "model_author": "k2-fsa",
253
+ "comment": "ZipVoice text encoder",
254
+ }
255
+ print(f"meta_data: {meta_data}")
256
+ add_meta_data(filename=filename, meta_data=meta_data)
257
+
258
+ print(f"Exported to {filename}")
259
+
260
+
261
+ def export_fm_decoder(
262
+ model: OnnxFlowMatchingModel,
263
+ filename: str,
264
+ opset_version: int = 11,
265
+ ) -> None:
266
+ """Export the flow matching decoder model to ONNX format.
267
+
268
+ Args:
269
+ model:
270
+ The input model
271
+ filename:
272
+ The filename to save the exported ONNX model.
273
+ opset_version:
274
+ The opset version to use.
275
+ """
276
+ feat_dim = model.feat_dim
277
+ seq_len = 200
278
+ t = torch.tensor(0.5, dtype=torch.float32)
279
+ x = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
280
+ text_condition = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
281
+ speech_condition = torch.randn(1, seq_len, feat_dim, dtype=torch.float32)
282
+ guidance_scale = torch.tensor(1.0, dtype=torch.float32)
283
+
284
+ model = torch.jit.trace(
285
+ model, (t, x, text_condition, speech_condition, guidance_scale)
286
+ )
287
+
288
+ torch.onnx.export(
289
+ model,
290
+ (t, x, text_condition, speech_condition, guidance_scale),
291
+ filename,
292
+ verbose=False,
293
+ opset_version=opset_version,
294
+ input_names=["t", "x", "text_condition", "speech_condition", "guidance_scale"],
295
+ output_names=["v"],
296
+ dynamic_axes={
297
+ "x": {0: "N", 1: "T"},
298
+ "text_condition": {0: "N", 1: "T"},
299
+ "speech_condition": {0: "N", 1: "T"},
300
+ "v": {0: "N", 1: "T"},
301
+ },
302
+ )
303
+
304
+ meta_data = {
305
+ "version": "1",
306
+ "model_author": "k2-fsa",
307
+ "comment": "ZipVoice flow-matching decoder",
308
+ "feat_dim": str(feat_dim),
309
+ }
310
+ print(f"meta_data: {meta_data}")
311
+ add_meta_data(filename=filename, meta_data=meta_data)
312
+
313
+ print(f"Exported to {filename}")
314
+
315
+
316
+ @torch.no_grad()
317
+ def main():
318
+ parser = get_parser()
319
+ args = parser.parse_args()
320
+
321
+ params = AttributeDict()
322
+ params.update(vars(args))
323
+
324
+ model_config = params.model_config
325
+ with open(model_config, "r") as f:
326
+ model_config = json.load(f)
327
+ for key, value in model_config["model"].items():
328
+ setattr(params, key, value)
329
+ for key, value in model_config["feature"].items():
330
+ setattr(params, key, value)
331
+
332
+ token_file = params.token_file
333
+ tokenizer = SimpleTokenizer(token_file)
334
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
335
+
336
+ if params.model_name == "zipvoice":
337
+ model = ZipVoice(
338
+ **model_config["model"],
339
+ **tokenizer_config,
340
+ )
341
+ else:
342
+ assert params.model_name == "zipvoice_distill"
343
+ model = ZipVoiceDistill(
344
+ **model_config["model"],
345
+ **tokenizer_config,
346
+ )
347
+ model_ckpt = params.checkpoint
348
+
349
+ if model_ckpt.endswith(".safetensors"):
350
+ safetensors.torch.load_model(model, model_ckpt)
351
+ elif model_ckpt.endswith(".pt"):
352
+ load_checkpoint(filename=model_ckpt, model=model, strict=True)
353
+ else:
354
+ raise NotImplementedError(f"Unsupported model checkpoint format: {model_ckpt}")
355
+
356
+ device = torch.device("cpu")
357
+ model = model.to(device)
358
+ model.eval()
359
+
360
+ convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
361
+
362
+ print("Exporting model")
363
+ os.makedirs(params.onnx_model_dir, exist_ok=True)
364
+ opset_version = 11
365
+
366
+ text_encoder = OnnxTextModel(model=model)
367
+ text_encoder_file = f"{params.onnx_model_dir}/text_encoder.onnx"
368
+ export_text_encoder(
369
+ model=text_encoder,
370
+ filename=text_encoder_file,
371
+ opset_version=opset_version,
372
+ )
373
+
374
+ fm_decoder = OnnxFlowMatchingModel(model=model)
375
+ fm_decoder_file = f"{params.onnx_model_dir}/fm_decoder.onnx"
376
+ export_fm_decoder(
377
+ model=fm_decoder,
378
+ filename=fm_decoder_file,
379
+ opset_version=opset_version,
380
+ )
381
+
382
+ print("Generate int8 quantization models")
383
+
384
+ text_encoder_int8_file = f"{params.onnx_model_dir}/text_encoder_int8.onnx"
385
+ quantize_dynamic(
386
+ model_input=text_encoder_file,
387
+ model_output=text_encoder_int8_file,
388
+ op_types_to_quantize=["MatMul"],
389
+ weight_type=QuantType.QInt8,
390
+ )
391
+
392
+ fm_decoder_int8_file = f"{params.onnx_model_dir}/fm_decoder_int8.onnx"
393
+ quantize_dynamic(
394
+ model_input=fm_decoder_file,
395
+ model_output=fm_decoder_int8_file,
396
+ op_types_to_quantize=["MatMul"],
397
+ weight_type=QuantType.QInt8,
398
+ )
399
+
400
+ print("Done!")
401
+
402
+
403
+ if __name__ == "__main__":
404
+ main()
zipvoice/bin/train_zipvoice.py ADDED
@@ -0,0 +1,1110 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024-2025 Xiaomi Corp. (authors: Wei Kang,
3
+ # Han Zhu)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ """
20
+ This script trains a ZipVoice model with the flow-matching loss.
21
+
22
+ Usage:
23
+
24
+ python3 -m zipvoice.bin.train_zipvoice \
25
+ --world-size 8 \
26
+ --use-fp16 1 \
27
+ --num-epochs 11 \
28
+ --max-duration 500 \
29
+ --lr-hours 30000 \
30
+ --model-config conf/zipvoice_base.json \
31
+ --tokenizer emilia \
32
+ --token-file "data/tokens_emilia.txt" \
33
+ --dataset emilia \
34
+ --manifest-dir data/fbank \
35
+ --exp-dir exp/zipvoice
36
+ """
37
+
38
+ import argparse
39
+ import copy
40
+ import json
41
+ import logging
42
+ import os
43
+ from functools import partial
44
+ from pathlib import Path
45
+ from shutil import copyfile
46
+ from typing import List, Optional, Tuple, Union
47
+
48
+ import torch
49
+ import torch.multiprocessing as mp
50
+ import torch.nn as nn
51
+ from lhotse.cut import Cut, CutSet
52
+ from lhotse.utils import fix_random_seed
53
+ from torch import Tensor
54
+ from torch.amp import GradScaler, autocast
55
+ from torch.nn.parallel import DistributedDataParallel as DDP
56
+ from torch.optim import Optimizer
57
+ from torch.utils.tensorboard import SummaryWriter
58
+
59
+ import zipvoice.utils.diagnostics as diagnostics
60
+ from zipvoice.dataset.datamodule import TtsDataModule
61
+ from zipvoice.models.zipvoice import ZipVoice
62
+ from zipvoice.tokenizer.tokenizer import (
63
+ EmiliaTokenizer,
64
+ EspeakTokenizer,
65
+ LibriTTSTokenizer,
66
+ SimpleTokenizer,
67
+ )
68
+ from zipvoice.utils.checkpoint import (
69
+ load_checkpoint,
70
+ remove_checkpoints,
71
+ resume_checkpoint,
72
+ save_checkpoint,
73
+ save_checkpoint_with_global_batch_idx,
74
+ update_averaged_model,
75
+ )
76
+ from zipvoice.utils.common import (
77
+ AttributeDict,
78
+ MetricsTracker,
79
+ cleanup_dist,
80
+ get_adjusted_batch_count,
81
+ get_env_info,
82
+ get_parameter_groups_with_lrs,
83
+ prepare_input,
84
+ set_batch_count,
85
+ setup_dist,
86
+ setup_logger,
87
+ str2bool,
88
+ )
89
+ from zipvoice.utils.hooks import register_inf_check_hooks
90
+ from zipvoice.utils.lr_scheduler import Eden, FixedLRScheduler, LRScheduler
91
+ from zipvoice.utils.optim import ScaledAdam
92
+
93
+ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
94
+
95
+
96
+ def get_parser():
97
+ parser = argparse.ArgumentParser(
98
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
99
+ )
100
+
101
+ parser.add_argument(
102
+ "--world-size",
103
+ type=int,
104
+ default=1,
105
+ help="Number of GPUs for DDP training.",
106
+ )
107
+
108
+ parser.add_argument(
109
+ "--master-port",
110
+ type=int,
111
+ default=12356,
112
+ help="Master port to use for DDP training.",
113
+ )
114
+
115
+ parser.add_argument(
116
+ "--tensorboard",
117
+ type=str2bool,
118
+ default=True,
119
+ help="Should various information be logged in tensorboard.",
120
+ )
121
+
122
+ parser.add_argument(
123
+ "--num-epochs",
124
+ type=int,
125
+ default=11,
126
+ help="Number of epochs to train.",
127
+ )
128
+
129
+ parser.add_argument(
130
+ "--num-iters",
131
+ type=int,
132
+ default=0,
133
+ help="Number of iter to train, will ignore num_epochs if > 0.",
134
+ )
135
+
136
+ parser.add_argument(
137
+ "--start-epoch",
138
+ type=int,
139
+ default=1,
140
+ help="""Resume training from this epoch. It should be positive.
141
+ If larger than 1, it will load checkpoint from
142
+ exp-dir/epoch-{start_epoch-1}.pt
143
+ """,
144
+ )
145
+
146
+ parser.add_argument(
147
+ "--checkpoint",
148
+ type=str,
149
+ default=None,
150
+ help="""Checkpoints of pre-trained models, will load it if not None
151
+ """,
152
+ )
153
+
154
+ parser.add_argument(
155
+ "--exp-dir",
156
+ type=str,
157
+ default="exp/zipvoice",
158
+ help="""The experiment dir.
159
+ It specifies the directory where all training related
160
+ files, e.g., checkpoints, log, etc, are saved
161
+ """,
162
+ )
163
+
164
+ parser.add_argument(
165
+ "--base-lr", type=float, default=0.02, help="The base learning rate."
166
+ )
167
+
168
+ parser.add_argument(
169
+ "--lr-batches",
170
+ type=float,
171
+ default=7500,
172
+ help="""Number of steps that affects how rapidly the learning rate
173
+ decreases. We suggest not to change this.""",
174
+ )
175
+
176
+ parser.add_argument(
177
+ "--lr-epochs",
178
+ type=float,
179
+ default=10,
180
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
181
+ """,
182
+ )
183
+
184
+ parser.add_argument(
185
+ "--lr-hours",
186
+ type=float,
187
+ default=0,
188
+ help="""If positive, --epoch is ignored and it specifies the number of hours
189
+ that affects how rapidly the learning rate decreases.
190
+ """,
191
+ )
192
+
193
+ parser.add_argument(
194
+ "--ref-duration",
195
+ type=float,
196
+ default=50,
197
+ help="""Reference batch duration for purposes of adjusting batch counts for"
198
+ setting various schedules inside the model".
199
+ """,
200
+ )
201
+
202
+ parser.add_argument(
203
+ "--finetune",
204
+ type=str2bool,
205
+ default=False,
206
+ help="Whether to use the fine-tuning mode, will used a fixed learning rate "
207
+ "schedule and skip the large dropout phase.",
208
+ )
209
+
210
+ parser.add_argument(
211
+ "--seed",
212
+ type=int,
213
+ default=42,
214
+ help="The seed for random generators intended for reproducibility",
215
+ )
216
+
217
+ parser.add_argument(
218
+ "--print-diagnostics",
219
+ type=str2bool,
220
+ default=False,
221
+ help="Accumulate stats on activations, print them and exit.",
222
+ )
223
+
224
+ parser.add_argument(
225
+ "--scan-oom",
226
+ type=str2bool,
227
+ default=False,
228
+ help="Scan pessimistic batches to see whether they cause OOMs.",
229
+ )
230
+
231
+ parser.add_argument(
232
+ "--inf-check",
233
+ type=str2bool,
234
+ default=False,
235
+ help="Add hooks to check for infinite module outputs and gradients.",
236
+ )
237
+
238
+ parser.add_argument(
239
+ "--save-every-n",
240
+ type=int,
241
+ default=5000,
242
+ help="""Save checkpoint after processing this number of batches"
243
+ periodically. We save checkpoint to exp-dir/ whenever
244
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
245
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
246
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
247
+ end of each epoch where `xxx` is the epoch number counting from 1.
248
+ """,
249
+ )
250
+
251
+ parser.add_argument(
252
+ "--keep-last-k",
253
+ type=int,
254
+ default=30,
255
+ help="""Only keep this number of checkpoints on disk.
256
+ For instance, if it is 3, there are only 3 checkpoints
257
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
258
+ It does not affect checkpoints with name `epoch-xxx.pt`.
259
+ """,
260
+ )
261
+
262
+ parser.add_argument(
263
+ "--average-period",
264
+ type=int,
265
+ default=200,
266
+ help="""Update the averaged model, namely `model_avg`, after processing
267
+ this number of batches. `model_avg` is a separate version of model,
268
+ in which each floating-point parameter is the average of all the
269
+ parameters from the start of training. Each time we take the average,
270
+ we do: `model_avg = model * (average_period / batch_idx_train) +
271
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
272
+ """,
273
+ )
274
+
275
+ parser.add_argument(
276
+ "--use-fp16",
277
+ type=str2bool,
278
+ default=True,
279
+ help="Whether to use half precision training.",
280
+ )
281
+
282
+ parser.add_argument(
283
+ "--feat-scale",
284
+ type=float,
285
+ default=0.1,
286
+ help="The scale factor of fbank feature",
287
+ )
288
+
289
+ parser.add_argument(
290
+ "--condition-drop-ratio",
291
+ type=float,
292
+ default=0.2,
293
+ help="The drop rate of text condition during training.",
294
+ )
295
+
296
+ parser.add_argument(
297
+ "--dataset",
298
+ type=str,
299
+ default="emilia",
300
+ choices=["emilia", "libritts", "custom"],
301
+ help="The used training dataset",
302
+ )
303
+
304
+ parser.add_argument(
305
+ "--train-manifest",
306
+ type=str,
307
+ help="Path of the training manifest",
308
+ )
309
+
310
+ parser.add_argument(
311
+ "--dev-manifest",
312
+ type=str,
313
+ help="Path of the validation manifest",
314
+ )
315
+
316
+ parser.add_argument(
317
+ "--min-len",
318
+ type=float,
319
+ default=1.0,
320
+ help="The minimum audio length used for training",
321
+ )
322
+
323
+ parser.add_argument(
324
+ "--max-len",
325
+ type=float,
326
+ default=30.0,
327
+ help="The maximum audio length used for training",
328
+ )
329
+
330
+ parser.add_argument(
331
+ "--model-config",
332
+ type=str,
333
+ default="conf/zipvoice_base.json",
334
+ help="The model configuration file.",
335
+ )
336
+
337
+ parser.add_argument(
338
+ "--tokenizer",
339
+ type=str,
340
+ default="emilia",
341
+ choices=["emilia", "libritts", "espeak", "simple"],
342
+ help="Tokenizer type.",
343
+ )
344
+
345
+ parser.add_argument(
346
+ "--lang",
347
+ type=str,
348
+ default="en-us",
349
+ help="Language identifier, used when tokenizer type is espeak. see"
350
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
351
+ )
352
+
353
+ parser.add_argument(
354
+ "--token-file",
355
+ type=str,
356
+ default="data/tokens_emilia.txt",
357
+ help="The file that contains information that maps tokens to ids,"
358
+ "which is a text file with '{token}\t{token_id}' per line.",
359
+ )
360
+
361
+ return parser
362
+
363
+
364
+ def get_params() -> AttributeDict:
365
+ """Return a dict containing training parameters.
366
+
367
+ All training related parameters that are not passed from the commandline
368
+ are saved in the variable `params`.
369
+
370
+ Commandline options are merged into `params` after they are parsed, so
371
+ you can also access them via `params`.
372
+
373
+ Explanation of options saved in `params`:
374
+
375
+ - best_train_loss: Best training loss so far. It is used to select
376
+ the model that has the lowest training loss. It is
377
+ updated during the training.
378
+
379
+ - best_valid_loss: Best validation loss so far. It is used to select
380
+ the model that has the lowest validation loss. It is
381
+ updated during the training.
382
+
383
+ - best_train_epoch: It is the epoch that has the best training loss.
384
+
385
+ - best_valid_epoch: It is the epoch that has the best validation loss.
386
+
387
+ - batch_idx_train: Used for writing statistics to tensorboard. It
388
+ contains number of batches trained so far across
389
+ epochs.
390
+
391
+ - log_interval: Print training loss if `batch_idx % log_interval` is 0
392
+
393
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
394
+
395
+ - env_info: A dict containing information about the environment.
396
+
397
+ """
398
+ params = AttributeDict(
399
+ {
400
+ "best_train_loss": float("inf"),
401
+ "best_valid_loss": float("inf"),
402
+ "best_train_epoch": -1,
403
+ "best_valid_epoch": -1,
404
+ "batch_idx_train": 0,
405
+ "log_interval": 50,
406
+ "reset_interval": 200,
407
+ "env_info": get_env_info(),
408
+ }
409
+ )
410
+
411
+ return params
412
+
413
+
414
+ def compute_fbank_loss(
415
+ params: AttributeDict,
416
+ model: Union[nn.Module, DDP],
417
+ features: Tensor,
418
+ features_lens: Tensor,
419
+ tokens: List[List[int]],
420
+ is_training: bool,
421
+ ) -> Tuple[Tensor, MetricsTracker]:
422
+ """
423
+ Compute loss given the model and its inputs.
424
+
425
+ Args:
426
+ params:
427
+ Parameters for training. See :func:`get_params`.
428
+ model:
429
+ The model for training.
430
+ features:
431
+ The target acoustic feature.
432
+ features_lens:
433
+ The number of frames of each utterance.
434
+ tokens:
435
+ Input tokens that represent the transcripts.
436
+ is_training:
437
+ True for training. False for validation. When it is True, this
438
+ function enables autograd during computation; when it is False, it
439
+ disables autograd.
440
+ """
441
+
442
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
443
+
444
+ batch_size, num_frames, _ = features.shape
445
+
446
+ features = torch.nn.functional.pad(
447
+ features, (0, 0, 0, num_frames - features.size(1))
448
+ ) # (B, T, F)
449
+ noise = torch.randn_like(features) # (B, T, F)
450
+
451
+ # Sampling t from uniform distribution
452
+ if is_training:
453
+ t = torch.rand(batch_size, 1, 1, device=device)
454
+ else:
455
+ t = (
456
+ (torch.arange(batch_size, device=device) / batch_size)
457
+ .unsqueeze(1)
458
+ .unsqueeze(2)
459
+ )
460
+ with torch.set_grad_enabled(is_training):
461
+
462
+ loss = model(
463
+ tokens=tokens,
464
+ features=features,
465
+ features_lens=features_lens,
466
+ noise=noise,
467
+ t=t,
468
+ condition_drop_ratio=params.condition_drop_ratio,
469
+ )
470
+
471
+ assert loss.requires_grad == is_training
472
+ info = MetricsTracker()
473
+ num_frames = features_lens.sum().item()
474
+ info["frames"] = num_frames
475
+ info["loss"] = loss.detach().cpu().item() * num_frames
476
+
477
+ return loss, info
478
+
479
+
480
+ def train_one_epoch(
481
+ params: AttributeDict,
482
+ model: Union[nn.Module, DDP],
483
+ optimizer: Optimizer,
484
+ scheduler: LRSchedulerType,
485
+ train_dl: torch.utils.data.DataLoader,
486
+ valid_dl: torch.utils.data.DataLoader,
487
+ scaler: GradScaler,
488
+ model_avg: Optional[nn.Module] = None,
489
+ tb_writer: Optional[SummaryWriter] = None,
490
+ world_size: int = 1,
491
+ rank: int = 0,
492
+ ) -> None:
493
+ """Train the model for one epoch.
494
+
495
+ The training loss from the mean of all frames is saved in
496
+ `params.train_loss`. It runs the validation process every
497
+ `params.valid_interval` batches.
498
+
499
+ Args:
500
+ params:
501
+ It is returned by :func:`get_params`.
502
+ model:
503
+ The model for training.
504
+ optimizer:
505
+ The optimizer.
506
+ scheduler:
507
+ The learning rate scheduler, we call step() every epoch.
508
+ train_dl:
509
+ Dataloader for the training dataset.
510
+ valid_dl:
511
+ Dataloader for the validation dataset.
512
+ scaler:
513
+ The scaler used for mixed precision training.
514
+ tb_writer:
515
+ Writer to write log messages to tensorboard.
516
+ world_size:
517
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
518
+ rank:
519
+ The rank of the node in DDP training. If no DDP is used, it should
520
+ be set to 0.
521
+ """
522
+ model.train()
523
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
524
+
525
+ # used to track the stats over iterations in one epoch
526
+ tot_loss = MetricsTracker()
527
+
528
+ saved_bad_model = False
529
+
530
+ def save_bad_model(suffix: str = ""):
531
+ save_checkpoint(
532
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
533
+ model=model,
534
+ model_avg=model_avg,
535
+ params=params,
536
+ optimizer=optimizer,
537
+ scheduler=scheduler,
538
+ sampler=train_dl.sampler,
539
+ scaler=scaler,
540
+ rank=0,
541
+ )
542
+
543
+ for batch_idx, batch in enumerate(train_dl):
544
+
545
+ if batch_idx % 10 == 0:
546
+ if params.finetune:
547
+ set_batch_count(model, get_adjusted_batch_count(params) + 100000)
548
+ else:
549
+ set_batch_count(model, get_adjusted_batch_count(params))
550
+
551
+ if (
552
+ params.batch_idx_train > 0
553
+ and params.batch_idx_train % params.valid_interval == 0
554
+ and not params.print_diagnostics
555
+ ):
556
+ logging.info("Computing validation loss")
557
+ valid_info = compute_validation_loss(
558
+ params=params,
559
+ model=model,
560
+ valid_dl=valid_dl,
561
+ world_size=world_size,
562
+ )
563
+ model.train()
564
+ logging.info(
565
+ f"Epoch {params.cur_epoch}, global_batch_idx: {params.batch_idx_train},"
566
+ f" validation: {valid_info}"
567
+ )
568
+ logging.info(
569
+ f"Maximum memory allocated so far is "
570
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
571
+ )
572
+ if tb_writer is not None:
573
+ valid_info.write_summary(
574
+ tb_writer, "train/valid_", params.batch_idx_train
575
+ )
576
+
577
+ params.batch_idx_train += 1
578
+
579
+ batch_size = len(batch["text"])
580
+
581
+ tokens, features, features_lens = prepare_input(
582
+ params=params,
583
+ batch=batch,
584
+ device=device,
585
+ return_tokens=True,
586
+ return_feature=True,
587
+ )
588
+
589
+ try:
590
+ with autocast("cuda", enabled=params.use_fp16):
591
+ loss, loss_info = compute_fbank_loss(
592
+ params=params,
593
+ model=model,
594
+ features=features,
595
+ features_lens=features_lens,
596
+ tokens=tokens,
597
+ is_training=True,
598
+ )
599
+
600
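+ # Keep an exponentially decaying running total so tot_loss reflects roughly the last reset_interval batches.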
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
601
+
602
+ scaler.scale(loss).backward()
603
+
604
+ scheduler.step_batch(params.batch_idx_train)
605
+ # Use the number of hours of speech to adjust the learning rate
606
+ if params.lr_hours > 0:
607
+ scheduler.step_epoch(
608
+ params.batch_idx_train
609
+ * params.max_duration
610
+ * params.world_size
611
+ / 3600
612
+ )
613
+ scaler.step(optimizer)
614
+ scaler.update()
615
+ optimizer.zero_grad()
616
+ except Exception as e:
617
+ logging.info(f"Caught exception : {e}.")
618
+ save_bad_model()
619
+ raise
620
+
621
+ if params.print_diagnostics and batch_idx == 5:
622
+ return
623
+
624
+ if (
625
+ rank == 0
626
+ and params.batch_idx_train > 0
627
+ and params.batch_idx_train % params.average_period == 0
628
+ ):
629
+ update_averaged_model(
630
+ params=params,
631
+ model_cur=model,
632
+ model_avg=model_avg,
633
+ )
634
+
635
+ if (
636
+ params.batch_idx_train > 0
637
+ and params.batch_idx_train % params.save_every_n == 0
638
+ ):
639
+ save_checkpoint_with_global_batch_idx(
640
+ out_dir=params.exp_dir,
641
+ global_batch_idx=params.batch_idx_train,
642
+ model=model,
643
+ model_avg=model_avg,
644
+ params=params,
645
+ optimizer=optimizer,
646
+ scheduler=scheduler,
647
+ sampler=train_dl.sampler,
648
+ scaler=scaler,
649
+ rank=rank,
650
+ )
651
+ remove_checkpoints(
652
+ out_dir=params.exp_dir,
653
+ topk=params.keep_last_k,
654
+ rank=rank,
655
+ )
656
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
657
+ break
658
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
659
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
660
+ # of the grad scaler is configurable, but we can't configure it to have
661
+ # different behavior depending on the current grad scale.
662
+ cur_grad_scale = scaler._scale.item()
663
+
664
+ if cur_grad_scale < 1024.0 or (
665
+ cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
666
+ ):
667
+ scaler.update(cur_grad_scale * 2.0)
668
+ if cur_grad_scale < 0.01:
669
+ if not saved_bad_model:
670
+ save_bad_model(suffix="-first-warning")
671
+ saved_bad_model = True
672
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
673
+ if cur_grad_scale < 1.0e-05:
674
+ save_bad_model()
675
+ raise RuntimeError(
676
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
677
+ )
678
+
679
+ if params.batch_idx_train % params.log_interval == 0:
680
+ cur_lr = max(scheduler.get_last_lr())
681
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
682
+
683
+ logging.info(
684
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
685
+ f"global_batch_idx: {params.batch_idx_train}, "
686
+ f"batch size: {batch_size}, "
687
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
688
+ f"cur_lr: {cur_lr:.2e}, "
689
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
690
+ )
691
+
692
+ if tb_writer is not None:
693
+ tb_writer.add_scalar(
694
+ "train/learning_rate", cur_lr, params.batch_idx_train
695
+ )
696
+ loss_info.write_summary(
697
+ tb_writer, "train/current_", params.batch_idx_train
698
+ )
699
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
700
+ if params.use_fp16:
701
+ tb_writer.add_scalar(
702
+ "train/grad_scale",
703
+ cur_grad_scale,
704
+ params.batch_idx_train,
705
+ )
706
+
707
+ loss_value = tot_loss["loss"]
708
+ params.train_loss = loss_value
709
+ if params.train_loss < params.best_train_loss:
710
+ params.best_train_epoch = params.cur_epoch
711
+ params.best_train_loss = params.train_loss
712
+
713
+
714
+ def compute_validation_loss(
715
+ params: AttributeDict,
716
+ model: Union[nn.Module, DDP],
717
+ valid_dl: torch.utils.data.DataLoader,
718
+ world_size: int = 1,
719
+ ) -> MetricsTracker:
720
+ """Run the validation process."""
721
+
722
+ model.eval()
723
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
724
+
725
+ # used to summary the stats over iterations
726
+ tot_loss = MetricsTracker()
727
+
728
+ for batch_idx, batch in enumerate(valid_dl):
729
+ tokens, features, features_lens = prepare_input(
730
+ params=params,
731
+ batch=batch,
732
+ device=device,
733
+ return_tokens=True,
734
+ return_feature=True,
735
+ )
736
+
737
+ loss, loss_info = compute_fbank_loss(
738
+ params=params,
739
+ model=model,
740
+ features=features,
741
+ features_lens=features_lens,
742
+ tokens=tokens,
743
+ is_training=False,
744
+ )
745
+ assert loss.requires_grad is False
746
+ tot_loss = tot_loss + loss_info
747
+
748
+ if world_size > 1:
749
+ tot_loss.reduce(loss.device)
750
+
751
+ loss_value = tot_loss["loss"]
752
+ if loss_value < params.best_valid_loss:
753
+ params.best_valid_epoch = params.cur_epoch
754
+ params.best_valid_loss = loss_value
755
+
756
+ return tot_loss
757
+
758
+
759
+ def display_and_save_batch(
760
+ batch: dict,
761
+ params: AttributeDict,
762
+ ) -> None:
763
+ """Display the batch statistics and save the batch into disk.
764
+
765
+ Args:
766
+ batch:
767
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
768
+ for the content in it.
769
+ params:
770
+ Parameters for training. See :func:`get_params`.
773
+ """
774
+ from lhotse.utils import uuid4
775
+
776
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
777
+ logging.info(f"Saving batch to {filename}")
778
+ torch.save(batch, filename)
779
+
780
+ features = batch["features"]
781
+ tokens = batch["tokens"]
782
+
783
+ logging.info(f"features shape: {features.shape}")
784
+ num_tokens = sum(len(i) for i in tokens)
785
+ logging.info(f"num tokens: {num_tokens}")
786
+
787
+
788
+ def scan_pessimistic_batches_for_oom(
789
+ model: Union[nn.Module, DDP],
790
+ train_dl: torch.utils.data.DataLoader,
791
+ optimizer: torch.optim.Optimizer,
792
+ params: AttributeDict,
793
+ ):
794
+ from lhotse.dataset import find_pessimistic_batches
795
+
796
+ logging.info(
797
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
798
+ )
799
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
800
+
801
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
802
+ for criterion, cuts in batches.items():
803
+ batch = train_dl.dataset[cuts]
804
+ tokens, features, features_lens = prepare_input(
805
+ params=params,
806
+ batch=batch,
807
+ device=device,
808
+ return_tokens=True,
809
+ return_feature=True,
810
+ )
811
+ try:
812
+ with autocast("cuda", enabled=params.use_fp16):
813
+
814
+ loss, loss_info = compute_fbank_loss(
815
+ params=params,
816
+ model=model,
817
+ features=features,
818
+ features_lens=features_lens,
819
+ tokens=tokens,
820
+ is_training=True,
821
+ )
822
+ loss.backward()
823
+ optimizer.zero_grad()
824
+ except Exception as e:
825
+ if "CUDA out of memory" in str(e):
826
+ logging.error(
827
+ "Your GPU ran out of memory with the current "
828
+ "max_duration setting. We recommend decreasing "
829
+ "max_duration and trying again.\n"
830
+ f"Failing criterion: {criterion} "
831
+ f"(={crit_values[criterion]}) ..."
832
+ )
833
+ display_and_save_batch(batch, params=params)
834
+ raise
835
+ logging.info(
836
+ f"Maximum memory allocated so far is "
837
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
838
+ )
839
+
840
+
841
+ def tokenize_text(c: Cut, tokenizer):
842
+ text = c.supervisions[0].text
843
+ tokens = tokenizer.texts_to_token_ids([text])
844
+ c.supervisions[0].tokens = tokens[0]
845
+ return c
846
+
847
+
848
+ def run(rank, world_size, args):
849
+ """
850
+ Args:
851
+ rank:
852
+ It is a value between 0 and `world_size-1`, which is
853
+ passed automatically by `mp.spawn()` in :func:`main`.
854
+ The node with rank 0 is responsible for saving checkpoint.
855
+ world_size:
856
+ Number of GPUs for DDP training.
857
+ args:
858
+ The return value of get_parser().parse_args()
859
+ """
860
+ params = get_params()
861
+ params.update(vars(args))
862
+ params.valid_interval = params.save_every_n
863
+ # Set epoch to a large number to ignore it.
864
+ if params.num_iters > 0:
865
+ params.num_epochs = 1000000
866
+ with open(params.model_config, "r") as f:
867
+ model_config = json.load(f)
868
+ params.update(model_config["model"])
869
+ params.update(model_config["feature"])
870
+
871
+ fix_random_seed(params.seed)
872
+ if world_size > 1:
873
+ setup_dist(rank, world_size, params.master_port)
874
+
875
+ os.makedirs(f"{params.exp_dir}", exist_ok=True)
876
+ copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
877
+ copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
878
+ setup_logger(f"{params.exp_dir}/log/log-train")
879
+
880
+ if args.tensorboard and rank == 0:
881
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
882
+ else:
883
+ tb_writer = None
884
+
885
+ if torch.cuda.is_available():
886
+ params.device = torch.device("cuda", rank)
887
+ else:
888
+ params.device = torch.device("cpu")
889
+ logging.info(f"Device: {params.device}")
890
+
891
+ if params.tokenizer == "emilia":
892
+ tokenizer = EmiliaTokenizer(token_file=params.token_file)
893
+ elif params.tokenizer == "libritts":
894
+ tokenizer = LibriTTSTokenizer(token_file=params.token_file)
895
+ elif params.tokenizer == "espeak":
896
+ tokenizer = EspeakTokenizer(token_file=params.token_file, lang=params.lang)
897
+ else:
898
+ assert params.tokenizer == "simple"
899
+ tokenizer = SimpleTokenizer(token_file=params.token_file)
900
+
901
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
902
+ params.update(tokenizer_config)
903
+
904
+ logging.info(params)
905
+
906
+ logging.info("About to create model")
907
+
908
+ model = ZipVoice(
909
+ **model_config["model"],
910
+ **tokenizer_config,
911
+ )
912
+
913
+ if params.checkpoint is not None:
914
+ logging.info(f"Loading pre-trained model from {params.checkpoint}")
915
+ _ = load_checkpoint(filename=params.checkpoint, model=model, strict=True)
916
+ num_param = sum([p.numel() for p in model.parameters()])
917
+ logging.info(f"Number of parameters : {num_param}")
918
+
919
+ model_avg: Optional[nn.Module] = None
920
+ if rank == 0:
921
+ # model_avg is only used with rank 0
922
+ model_avg = copy.deepcopy(model).to(torch.float64)
923
+
924
+ assert params.start_epoch > 0, params.start_epoch
925
+ if params.start_epoch > 1:
926
+ checkpoints = resume_checkpoint(params=params, model=model, model_avg=model_avg)
927
+
928
+ model = model.to(params.device)
929
+ if world_size > 1:
930
+ logging.info("Using DDP")
931
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
932
+
933
+ optimizer = ScaledAdam(
934
+ get_parameter_groups_with_lrs(
935
+ model,
936
+ lr=params.base_lr,
937
+ include_names=True,
938
+ ),
939
+ lr=params.base_lr, # should have no effect
940
+ clipping_scale=2.0,
941
+ )
942
+
943
+ assert params.lr_hours >= 0
944
+
945
+ if params.finetune:
946
+ scheduler = FixedLRScheduler(optimizer)
947
+ elif params.lr_hours > 0:
948
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_hours)
949
+ else:
950
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
951
+
952
+ scaler = GradScaler("cuda", enabled=params.use_fp16)
953
+
954
+ if params.start_epoch > 1 and checkpoints is not None:
955
+ # load state_dict for optimizers
956
+ if "optimizer" in checkpoints:
957
+ logging.info("Loading optimizer state dict")
958
+ optimizer.load_state_dict(checkpoints["optimizer"])
959
+
960
+ # load state_dict for schedulers
961
+ if "scheduler" in checkpoints:
962
+ logging.info("Loading scheduler state dict")
963
+ scheduler.load_state_dict(checkpoints["scheduler"])
964
+
965
+ if "grad_scaler" in checkpoints:
966
+ logging.info("Loading grad scaler state dict")
967
+ scaler.load_state_dict(checkpoints["grad_scaler"])
968
+
969
+ if params.print_diagnostics:
970
+ opts = diagnostics.TensorDiagnosticOptions(
971
+ 512
972
+ ) # allow 4 megabytes per sub-module
973
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
974
+
975
+ if params.inf_check:
976
+ register_inf_check_hooks(model)
977
+
978
+ def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
979
+ if c.duration < min_len or c.duration > max_len:
980
+ return False
981
+ return True
982
+
983
+ _remove_short_and_long_utt = partial(
984
+ remove_short_and_long_utt, min_len=params.min_len, max_len=params.max_len
985
+ )
986
+
987
+ datamodule = TtsDataModule(args)
988
+ if params.dataset == "emilia":
989
+ train_cuts = CutSet.mux(
990
+ datamodule.train_emilia_EN_cuts(),
991
+ datamodule.train_emilia_ZH_cuts(),
992
+ weights=[46000, 49000],
993
+ )
994
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
995
+ dev_cuts = CutSet.mux(
996
+ datamodule.dev_emilia_EN_cuts(),
997
+ datamodule.dev_emilia_ZH_cuts(),
998
+ weights=[0.5, 0.5],
999
+ )
1000
+ elif params.dataset == "libritts":
1001
+ train_cuts = datamodule.train_libritts_cuts()
1002
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
1003
+ dev_cuts = datamodule.dev_libritts_cuts()
1004
+ else:
1005
+ assert params.dataset == "custom"
1006
+ train_cuts = datamodule.train_custom_cuts(params.train_manifest)
1007
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
1008
+ dev_cuts = datamodule.dev_custom_cuts(params.dev_manifest)
1009
+ # To avoid OOM issues due to too long dev cuts
1010
+ dev_cuts = dev_cuts.filter(_remove_short_and_long_utt)
1011
+
1012
+ _tokenize_text = partial(tokenize_text, tokenizer=tokenizer)
1013
+ train_cuts = train_cuts.map(_tokenize_text)
1014
+ dev_cuts = dev_cuts.map(_tokenize_text)
1015
+
1016
+ train_dl = datamodule.train_dataloaders(train_cuts)
1017
+
1018
+ valid_dl = datamodule.dev_dataloaders(dev_cuts)
1019
+
1020
+ if params.scan_oom:
1021
+ scan_pessimistic_batches_for_oom(
1022
+ model=model,
1023
+ train_dl=train_dl,
1024
+ optimizer=optimizer,
1025
+ params=params,
1026
+ )
1027
+
1028
+ logging.info("Training started")
1029
+
1030
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
1031
+ logging.info(f"Start epoch {epoch}")
1032
+
1033
+ if params.lr_hours == 0:
1034
+ scheduler.step_epoch(epoch - 1)
1035
+ fix_random_seed(params.seed + epoch - 1)
1036
+ train_dl.sampler.set_epoch(epoch - 1)
1037
+
1038
+ params.cur_epoch = epoch
1039
+
1040
+ if tb_writer is not None:
1041
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
1042
+
1043
+ train_one_epoch(
1044
+ params=params,
1045
+ model=model,
1046
+ model_avg=model_avg,
1047
+ optimizer=optimizer,
1048
+ scheduler=scheduler,
1049
+ train_dl=train_dl,
1050
+ valid_dl=valid_dl,
1051
+ scaler=scaler,
1052
+ tb_writer=tb_writer,
1053
+ world_size=world_size,
1054
+ rank=rank,
1055
+ )
1056
+
1057
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
1058
+ break
1059
+
1060
+ if params.print_diagnostics:
1061
+ diagnostic.print_diagnostics()
1062
+ break
1063
+
1064
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
1065
+ save_checkpoint(
1066
+ filename=filename,
1067
+ params=params,
1068
+ model=model,
1069
+ model_avg=model_avg,
1070
+ optimizer=optimizer,
1071
+ scheduler=scheduler,
1072
+ sampler=train_dl.sampler,
1073
+ scaler=scaler,
1074
+ rank=rank,
1075
+ )
1076
+
1077
+ if rank == 0:
1078
+ if params.best_train_epoch == params.cur_epoch:
1079
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
1080
+ copyfile(src=filename, dst=best_train_filename)
1081
+
1082
+ if params.best_valid_epoch == params.cur_epoch:
1083
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
1084
+ copyfile(src=filename, dst=best_valid_filename)
1085
+
1086
+ logging.info("Done!")
1087
+
1088
+ if world_size > 1:
1089
+ torch.distributed.barrier()
1090
+ cleanup_dist()
1091
+
1092
+
1093
+ def main():
1094
+ parser = get_parser()
1095
+ TtsDataModule.add_arguments(parser)
1096
+ args = parser.parse_args()
1097
+ args.exp_dir = Path(args.exp_dir)
1098
+
1099
+ world_size = args.world_size
1100
+ assert world_size >= 1
1101
+ if world_size > 1:
1102
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
1103
+ else:
1104
+ run(rank=0, world_size=1, args=args)
1105
+
1106
+
1107
+ if __name__ == "__main__":
1108
+ torch.set_num_threads(1)
1109
+ torch.set_num_interop_threads(1)
1110
+ main()
zipvoice/bin/train_zipvoice_distill.py ADDED
@@ -0,0 +1,1159 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ """
20
+ This script trains a ZipVoice-Distill model starting from a ZipVoice model.
21
+ It has two distillation stages.
22
+
23
+ Usage:
24
+
25
+ (1) The first distillation stage with a fixed ZipVoice model as the teacher.
26
+
27
+ python3 -m zipvoice.bin.train_zipvoice_distill \
28
+ --world-size 8 \
29
+ --use-fp16 1 \
30
+ --num-iters 60000 \
31
+ --max-duration 500 \
32
+ --base-lr 0.0005 \
33
+ --tokenizer emilia \
34
+ --token-file data/tokens_emilia.txt \
35
+ --dataset emilia \
36
+ --manifest-dir data/fbank \
37
+ --teacher-model zipvoice/exp_zipvoice/epoch-11-avg-4.pt \
38
+ --distill-stage first \
39
+ --exp-dir exp/zipvoice_distill_1stage
40
+
41
+ (2) The second distillation stage with an EMA model as the teacher.
42
+ python3 -m zipvoice.bin.train_zipvoice_distill \
43
+ --world-size 8 \
44
+ --use-fp16 1 \
45
+ --num-iters 2000 \
46
+ --save-every-n 1000 \
47
+ --max-duration 500 \
48
+ --base-lr 0.0001 \
49
+ --model-config conf/zipvoice_base.json \
50
+ --tokenizer emilia \
51
+ --token-file data/tokens_emilia.txt \
52
+ --dataset emilia \
53
+ --manifest-dir data/fbank \
54
+ --teacher-model zipvoice/exp_zipvoice_distill_1stage/iter-60000-avg-7.pt \
55
+ --distill-stage second \
56
+ --exp-dir zipvoice/exp_zipvoice_distill
57
+ """
58
+
59
+ import argparse
60
+ import copy
61
+ import json
62
+ import logging
63
+ import os
64
+ import random
65
+ from functools import partial
66
+ from pathlib import Path
67
+ from shutil import copyfile
68
+ from typing import List, Optional, Tuple, Union
69
+
70
+ import torch
71
+ import torch.multiprocessing as mp
72
+ import torch.nn as nn
73
+ from lhotse.cut import Cut, CutSet
74
+ from lhotse.utils import fix_random_seed
75
+ from torch import Tensor
76
+ from torch.amp import GradScaler, autocast
77
+ from torch.nn.parallel import DistributedDataParallel as DDP
78
+ from torch.optim import Optimizer
79
+ from torch.utils.tensorboard import SummaryWriter
80
+
81
+ import zipvoice.utils.diagnostics as diagnostics
82
+ from zipvoice.bin.train_zipvoice import (
83
+ display_and_save_batch,
84
+ get_params,
85
+ tokenize_text,
86
+ )
87
+ from zipvoice.dataset.datamodule import TtsDataModule
88
+ from zipvoice.models.zipvoice import ZipVoice
89
+ from zipvoice.models.zipvoice_distill import ZipVoiceDistill
90
+ from zipvoice.tokenizer.tokenizer import (
91
+ EmiliaTokenizer,
92
+ EspeakTokenizer,
93
+ LibriTTSTokenizer,
94
+ SimpleTokenizer,
95
+ )
96
+ from zipvoice.utils.checkpoint import (
97
+ load_checkpoint,
98
+ remove_checkpoints,
99
+ resume_checkpoint,
100
+ save_checkpoint,
101
+ save_checkpoint_with_global_batch_idx,
102
+ update_averaged_model,
103
+ )
104
+ from zipvoice.utils.common import (
105
+ AttributeDict,
106
+ MetricsTracker,
107
+ cleanup_dist,
108
+ condition_time_mask,
109
+ get_adjusted_batch_count,
110
+ get_parameter_groups_with_lrs,
111
+ make_pad_mask,
112
+ prepare_input,
113
+ set_batch_count,
114
+ setup_dist,
115
+ setup_logger,
116
+ str2bool,
117
+ )
118
+ from zipvoice.utils.hooks import register_inf_check_hooks
119
+ from zipvoice.utils.lr_scheduler import FixedLRScheduler, LRScheduler
120
+ from zipvoice.utils.optim import ScaledAdam
121
+
122
+ LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, LRScheduler]
123
+
124
+
125
+ def get_parser():
126
+ parser = argparse.ArgumentParser(
127
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
128
+ )
129
+
130
+ parser.add_argument(
131
+ "--world-size",
132
+ type=int,
133
+ default=1,
134
+ help="Number of GPUs for DDP training.",
135
+ )
136
+
137
+ parser.add_argument(
138
+ "--master-port",
139
+ type=int,
140
+ default=12356,
141
+ help="Master port to use for DDP training.",
142
+ )
143
+
144
+ parser.add_argument(
145
+ "--tensorboard",
146
+ type=str2bool,
147
+ default=True,
148
+ help="Should various information be logged in tensorboard.",
149
+ )
150
+
151
+ parser.add_argument(
152
+ "--num-epochs",
153
+ type=int,
154
+ default=1,
155
+ help="Number of epochs to train.",
156
+ )
157
+
158
+ parser.add_argument(
159
+ "--num-iters",
160
+ type=int,
161
+ default=0,
162
+ help="Number of iter to train, will ignore num_epochs if > 0.",
163
+ )
164
+
165
+ parser.add_argument(
166
+ "--start-epoch",
167
+ type=int,
168
+ default=1,
169
+ help="""Resume training from this epoch. It should be positive.
170
+ If larger than 1, it will load checkpoint from
171
+ exp-dir/epoch-{start_epoch-1}.pt
172
+ """,
173
+ )
174
+
175
+ parser.add_argument(
176
+ "--teacher-model",
177
+ type=str,
178
+ help="""Checkpoints of pre-trained teacher model""",
179
+ )
180
+
181
+ parser.add_argument(
182
+ "--exp-dir",
183
+ type=str,
184
+ default="exp/zipvoice_distill",
185
+ help="""The experiment dir.
186
+ It specifies the directory where all training related
187
+ files, e.g., checkpoints, log, etc, are saved
188
+ """,
189
+ )
190
+
191
+ parser.add_argument(
192
+ "--base-lr", type=float, default=0.001, help="The base learning rate."
193
+ )
194
+
195
+ parser.add_argument(
196
+ "--ref-duration",
197
+ type=float,
198
+ default=50,
199
+ help="Reference batch duration for purposes of adjusting batch counts for "
200
+ "setting various schedules inside the model",
201
+ )
202
+
203
+ parser.add_argument(
204
+ "--seed",
205
+ type=int,
206
+ default=42,
207
+ help="The seed for random generators intended for reproducibility",
208
+ )
209
+
210
+ parser.add_argument(
211
+ "--print-diagnostics",
212
+ type=str2bool,
213
+ default=False,
214
+ help="Accumulate stats on activations, print them and exit.",
215
+ )
216
+
217
+ parser.add_argument(
218
+ "--scan-oom",
219
+ type=str2bool,
220
+ default=False,
221
+ help="Scan pessimistic batches to see whether they cause OOMs.",
222
+ )
223
+
224
+ parser.add_argument(
225
+ "--inf-check",
226
+ type=str2bool,
227
+ default=False,
228
+ help="Add hooks to check for infinite module outputs and gradients.",
229
+ )
230
+
231
+ parser.add_argument(
232
+ "--save-every-n",
233
+ type=int,
234
+ default=1000,
235
+ help="""Save checkpoint after processing this number of batches"
236
+ periodically. We save checkpoint to exp-dir/ whenever
237
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
238
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
239
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
240
+ end of each epoch where `xxx` is the epoch number counting from 1.
241
+ """,
242
+ )
243
+
244
+ parser.add_argument(
245
+ "--keep-last-k",
246
+ type=int,
247
+ default=30,
248
+ help="""Only keep this number of checkpoints on disk.
249
+ For instance, if it is 3, there are only 3 checkpoints
250
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
251
+ It does not affect checkpoints with name `epoch-xxx.pt`.
252
+ """,
253
+ )
254
+
255
+ parser.add_argument(
256
+ "--average-period",
257
+ type=int,
258
+ default=200,
259
+ help="""Update the averaged model, namely `model_avg`, after processing
260
+ this number of batches. `model_avg` is a separate version of model,
261
+ in which each floating-point parameter is the average of all the
262
+ parameters from the start of training. Each time we take the average,
263
+ we do: `model_avg = model * (average_period / batch_idx_train) +
264
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
265
+ """,
266
+ )
267
+
268
+ parser.add_argument(
269
+ "--use-fp16",
270
+ type=str2bool,
271
+ default=True,
272
+ help="Whether to use half precision training.",
273
+ )
274
+
275
+ parser.add_argument(
276
+ "--feat-scale",
277
+ type=float,
278
+ default=0.1,
279
+ help="The scale factor of fbank feature",
280
+ )
281
+
282
+ parser.add_argument(
283
+ "--ema-decay",
284
+ type=float,
285
+ default=0.9999,
286
+ help="The EMA decay factor of target model in distillation.",
287
+ )
288
+ parser.add_argument(
289
+ "--distill-stage",
290
+ type=str,
291
+ choices=["first", "second"],
292
+ help="The stage of distillation.",
293
+ )
294
+
295
+ parser.add_argument(
296
+ "--dataset",
297
+ type=str,
298
+ default="emilia",
299
+ choices=["emilia", "libritts", "custom"],
300
+ help="The used training dataset",
301
+ )
302
+
303
+ parser.add_argument(
304
+ "--train-manifest",
305
+ type=str,
306
+ help="Path of the training manifest",
307
+ )
308
+
309
+ parser.add_argument(
310
+ "--dev-manifest",
311
+ type=str,
312
+ help="Path of the validation manifest",
313
+ )
314
+
315
+ parser.add_argument(
316
+ "--min-len",
317
+ type=float,
318
+ default=1.0,
319
+ help="The minimum audio length used for training",
320
+ )
321
+
322
+ parser.add_argument(
323
+ "--max-len",
324
+ type=float,
325
+ default=30.0,
326
+ help="The maximum audio length used for training",
327
+ )
328
+
329
+ parser.add_argument(
330
+ "--model-config",
331
+ type=str,
332
+ default="conf/zipvoice_base.json",
333
+ help="The model configuration file.",
334
+ )
335
+
336
+ parser.add_argument(
337
+ "--tokenizer",
338
+ type=str,
339
+ default="emilia",
340
+ choices=["emilia", "libritts", "espeak", "simple"],
341
+ help="Tokenizer type.",
342
+ )
343
+
344
+ parser.add_argument(
345
+ "--lang",
346
+ type=str,
347
+ default="en-us",
348
+ help="Language identifier, used when tokenizer type is espeak. see"
349
+ "https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md",
350
+ )
351
+
360
+ parser.add_argument(
361
+ "--token-file",
362
+ type=str,
363
+ default="data/tokens_emilia.txt",
364
+ help="The file that contains information that maps tokens to ids,"
365
+ "which is a text file with '{token}\t{token_id}' per line.",
366
+ )
367
+
368
+ return parser
369
+
370
+
371
+ def ema(new_model, ema_model, decay):
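+ # In-place exponential moving average update: ema_param = decay * ema_param + (1 - decay) * new_param.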
372
+ if isinstance(new_model, DDP):
373
+ new_model = new_model.module
374
+ if isinstance(ema_model, DDP):
375
+ ema_model = ema_model.module
376
+ new_model_dict = new_model.state_dict()
377
+ ema_model_dict = ema_model.state_dict()
378
+ for key in new_model_dict.keys():
379
+ ema_model_dict[key].data.copy_(
380
+ ema_model_dict[key].data * decay + new_model_dict[key].data * (1 - decay)
381
+ )
382
+
383
+
384
+ def compute_fbank_loss(
385
+ params: AttributeDict,
386
+ model: Union[nn.Module, DDP],
387
+ teacher_model: Union[nn.Module, DDP],
388
+ features: Tensor,
389
+ features_lens: Tensor,
390
+ tokens: List[List[int]],
391
+ is_training: bool,
392
+ ) -> Tuple[Tensor, MetricsTracker]:
393
+ """
394
+ Compute loss given the model and its inputs.
395
+
396
+ Args:
397
+ params:
398
+ Parameters for training. See :func:`get_params`.
399
+ model:
400
+ The model for training.
401
+ teacher_model:
402
+ The teacher model for distillation.
403
+ features:
404
+ The target acoustic feature.
405
+ features_lens:
406
+ The number of frames of each utterance.
407
+ tokens:
408
+ Input tokens that represent the transcripts.
409
+ is_training:
410
+ True for training. False for validation. When it is True, this
411
+ function enables autograd during computation; when it is False, it
412
+ disables autograd.
413
+ """
414
+
415
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
416
+
417
+ batch_size, num_frames, _ = features.shape
418
+
419
+ features = torch.nn.functional.pad(
420
+ features, (0, 0, 0, num_frames - features.size(1))
421
+ ) # (B, T, F)
422
+ noise = torch.randn_like(features) # (B, T, F)
423
+
424
+ # Sampling t and guidance_scale from uniform distribution
425
+
426
+ t_value = random.random()
427
+ t = torch.ones(batch_size, 1, 1, device=device) * t_value
428
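+ # Guidance scale is drawn from [0, 2) in the first distillation stage and from [1, 3) in the second.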
+ if params.distill_stage == "first":
429
+ guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2
430
+ else:
431
+ guidance_scale = torch.rand(batch_size, 1, 1, device=device) * 2 + 1
432
+ xt = features * t + noise * (1 - t)
433
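+ # Split the interval after t into two random sub-steps: the teacher integrates them with two one-step calls, while the student must match the full jump from t to t_dest in a single step.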
+ t_delta_fix = random.uniform(0.0, min(0.3, 1 - t_value))
434
+ t_delta_ema = random.uniform(0.0, min(0.3, 1 - t_value - t_delta_fix))
435
+ t_dest = t_value + t_delta_fix + t_delta_ema
436
+
437
+ with torch.no_grad():
438
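+ # Mask 70-100% of the valid frames; the loss below is computed only on this masked region, while the unmasked frames act as the speech prompt condition.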
+ speech_condition_mask = condition_time_mask(
439
+ features_lens=features_lens,
440
+ mask_percent=(0.7, 1.0),
441
+ max_len=features.size(1),
442
+ )
443
+
444
+ if params.distill_stage == "first":
445
+ teacher_x_t_mid, _ = teacher_model.sample_intermediate(
446
+ tokens=tokens,
447
+ features=features,
448
+ features_lens=features_lens,
449
+ noise=xt,
450
+ speech_condition_mask=speech_condition_mask,
451
+ t_start=t_value,
452
+ t_end=t_value + t_delta_fix,
453
+ num_step=1,
454
+ guidance_scale=guidance_scale,
455
+ )
456
+
457
+ target_x1, _ = teacher_model.sample_intermediate(
458
+ tokens=tokens,
459
+ features=features,
460
+ features_lens=features_lens,
461
+ noise=teacher_x_t_mid,
462
+ speech_condition_mask=speech_condition_mask,
463
+ t_start=t_value + t_delta_fix,
464
+ t_end=t_dest,
465
+ num_step=1,
466
+ guidance_scale=guidance_scale,
467
+ )
468
+ else:
469
+ teacher_x_t_mid, _ = teacher_model(
470
+ tokens=tokens,
471
+ features=features,
472
+ features_lens=features_lens,
473
+ noise=xt,
474
+ speech_condition_mask=speech_condition_mask,
475
+ t_start=t_value,
476
+ t_end=t_value + t_delta_fix,
477
+ num_step=1,
478
+ guidance_scale=guidance_scale,
479
+ )
480
+
481
+ target_x1, _ = teacher_model(
482
+ tokens=tokens,
483
+ features=features,
484
+ features_lens=features_lens,
485
+ noise=teacher_x_t_mid,
486
+ speech_condition_mask=speech_condition_mask,
487
+ t_start=t_value + t_delta_fix,
488
+ t_end=t_dest,
489
+ num_step=1,
490
+ guidance_scale=guidance_scale,
491
+ )
492
+
493
+ with torch.set_grad_enabled(is_training):
494
+
495
+ pred_x1, _ = model(
496
+ tokens=tokens,
497
+ features=features,
498
+ features_lens=features_lens,
499
+ noise=xt,
500
+ speech_condition_mask=speech_condition_mask,
501
+ t_start=t,
502
+ t_end=t_dest,
503
+ num_step=1,
504
+ guidance_scale=guidance_scale,
505
+ )
506
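+ # Convert the predicted endpoint at t_dest into an average velocity over [t, t_dest], to be matched against the teacher's target below.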
+ pred_v = (pred_x1 - xt) / (t_dest - t)
507
+
508
+ padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T)
509
+ loss_mask = speech_condition_mask & (~padding_mask)
510
+
511
+ target_v = (target_x1 - xt) / (t_dest - t)
512
+ loss = torch.mean((pred_v[loss_mask] - target_v[loss_mask]) ** 2)
513
+
514
+ ut = features - noise # (B, T, F)
515
+
516
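+ # ref_loss measures the gap to the ground-truth flow-matching velocity ut; it is logged for monitoring only and not back-propagated.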
+ ref_loss = torch.mean((pred_v[loss_mask] - ut[loss_mask]) ** 2)
517
+
518
+ assert loss.requires_grad == is_training
519
+ info = MetricsTracker()
520
+ num_frames = features_lens.sum().item()
521
+ info["frames"] = num_frames
522
+ info["loss"] = loss.detach().cpu().item() * num_frames
523
+ info["ref_loss"] = ref_loss.detach().cpu().item() * num_frames
524
+ return loss, info
525
+
526
+
527
+ def train_one_epoch(
528
+ params: AttributeDict,
529
+ model: Union[nn.Module, DDP],
530
+ teacher_model: Union[nn.Module, DDP],
531
+ optimizer: Optimizer,
532
+ scheduler: LRSchedulerType,
533
+ train_dl: torch.utils.data.DataLoader,
534
+ valid_dl: torch.utils.data.DataLoader,
535
+ scaler: GradScaler,
536
+ model_avg: Optional[nn.Module] = None,
537
+ tb_writer: Optional[SummaryWriter] = None,
538
+ world_size: int = 1,
539
+ rank: int = 0,
540
+ ) -> None:
541
+ """Train the model for one epoch.
542
+
543
+ The training loss from the mean of all frames is saved in
544
+ `params.train_loss`. It runs the validation process every
545
+ `params.valid_interval` batches.
546
+
547
+ Args:
548
+ params:
549
+ It is returned by :func:`get_params`.
550
+ model:
551
+ The model for training.
552
+ teacher_model:
553
+ The teacher model used for distillation.
554
+ optimizer:
556
+ The optimizer.
557
+ scheduler:
558
+ The learning rate scheduler, we call step() every epoch.
559
+ train_dl:
560
+ Dataloader for the training dataset.
561
+ valid_dl:
562
+ Dataloader for the validation dataset.
563
+ scaler:
564
+ The scaler used for mixed precision training.
565
+ tb_writer:
566
+ Writer to write log messages to tensorboard.
567
+ world_size:
568
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
569
+ rank:
570
+ The rank of the node in DDP training. If no DDP is used, it should
571
+ be set to 0.
572
+ """
573
+ model.train()
574
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
575
+
576
+ # used to track the stats over iterations in one epoch
577
+ tot_loss = MetricsTracker()
578
+
579
+ saved_bad_model = False
580
+
581
+ def save_bad_model(suffix: str = ""):
582
+ save_checkpoint(
583
+ filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
584
+ model=model,
585
+ model_avg=model_avg,
586
+ model_ema=teacher_model,
587
+ params=params,
588
+ optimizer=optimizer,
589
+ scheduler=scheduler,
590
+ sampler=train_dl.sampler,
591
+ scaler=scaler,
592
+ rank=0,
593
+ )
594
+
595
+ for batch_idx, batch in enumerate(train_dl):
596
+
597
+ if batch_idx % 10 == 0:
598
+ set_batch_count(model, get_adjusted_batch_count(params) + 100000)
599
+
600
+ if (
601
+ params.batch_idx_train % params.valid_interval == 0
602
+ and not params.print_diagnostics
603
+ ):
604
+ logging.info("Computing validation loss")
605
+ valid_info = compute_validation_loss(
606
+ params=params,
607
+ model=model,
608
+ teacher_model=teacher_model,
609
+ valid_dl=valid_dl,
610
+ world_size=world_size,
611
+ )
612
+ model.train()
613
+ logging.info(
614
+ f"Epoch {params.cur_epoch}, global_batch_idx: {params.batch_idx_train},"
615
+ f" validation: {valid_info}"
616
+ )
617
+ logging.info(
618
+ f"Maximum memory allocated so far is "
619
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
620
+ )
621
+ if tb_writer is not None:
622
+ valid_info.write_summary(
623
+ tb_writer, "train/valid_", params.batch_idx_train
624
+ )
625
+
626
+ params.batch_idx_train += 1
627
+
628
+ batch_size = len(batch["text"])
629
+
630
+ tokens, features, features_lens = prepare_input(
631
+ params=params,
632
+ batch=batch,
633
+ device=device,
634
+ return_tokens=True,
635
+ return_feature=True,
636
+ )
637
+
638
+ try:
639
+ with autocast("cuda", enabled=params.use_fp16):
640
+ loss, loss_info = compute_fbank_loss(
641
+ params=params,
642
+ model=model,
643
+ teacher_model=teacher_model,
644
+ features=features,
645
+ features_lens=features_lens,
646
+ tokens=tokens,
647
+ is_training=True,
648
+ )
649
+
650
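+ # Exponentially decaying running average of the per-batch training stats.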
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
651
+
652
+ scaler.scale(loss).backward()
653
+
654
+ scheduler.step_batch(params.batch_idx_train)
655
+ scaler.step(optimizer)
656
+ scaler.update()
657
+ optimizer.zero_grad()
658
+ if params.distill_stage == "second":
659
+ ema(model, teacher_model, params.ema_decay)
660
+ except Exception as e:
661
+ logging.info(f"Caught exception : {e}.")
662
+ save_bad_model()
663
+ raise
664
+
665
+ if params.print_diagnostics and batch_idx == 5:
666
+ return
667
+
668
+ if (
669
+ rank == 0
670
+ and params.batch_idx_train > 0
671
+ and params.batch_idx_train % params.average_period == 0
672
+ ):
673
+ update_averaged_model(
674
+ params=params,
675
+ model_cur=model,
676
+ model_avg=model_avg,
677
+ )
678
+
679
+ if (
680
+ params.batch_idx_train > 0
681
+ and params.batch_idx_train % params.save_every_n == 0
682
+ ):
683
+ save_checkpoint_with_global_batch_idx(
684
+ out_dir=params.exp_dir,
685
+ global_batch_idx=params.batch_idx_train,
686
+ model=model,
687
+ model_avg=model_avg,
688
+ params=params,
689
+ optimizer=optimizer,
690
+ scheduler=scheduler,
691
+ sampler=train_dl.sampler,
692
+ scaler=scaler,
693
+ rank=rank,
694
+ )
695
+ remove_checkpoints(
696
+ out_dir=params.exp_dir,
697
+ topk=params.keep_last_k,
698
+ rank=rank,
699
+ )
700
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
701
+ break
702
+ if params.batch_idx_train % 100 == 0 and params.use_fp16:
703
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
704
+ # of the grad scaler is configurable, but we can't configure it to have
705
+ # different behavior depending on the current grad scale.
706
+ cur_grad_scale = scaler._scale.item()
707
+
708
+ if cur_grad_scale < 1024.0 or (
709
+ cur_grad_scale < 4096.0 and params.batch_idx_train % 400 == 0
710
+ ):
711
+ scaler.update(cur_grad_scale * 2.0)
712
+ if cur_grad_scale < 0.01:
713
+ if not saved_bad_model:
714
+ save_bad_model(suffix="-first-warning")
715
+ saved_bad_model = True
716
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
717
+ if cur_grad_scale < 1.0e-05:
718
+ save_bad_model()
719
+ raise RuntimeError(
720
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
721
+ )
722
+
723
+ if params.batch_idx_train % params.log_interval == 0:
724
+ cur_lr = max(scheduler.get_last_lr())
725
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
726
+
727
+ logging.info(
728
+ f"Epoch {params.cur_epoch}, batch {batch_idx}, "
729
+ f"global_batch_idx: {params.batch_idx_train}, "
730
+ f"batch size: {batch_size}, "
731
+ f"loss[{loss_info}], tot_loss[{tot_loss}], "
732
+ f"cur_lr: {cur_lr:.2e}, "
733
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
734
+ )
735
+
736
+ if tb_writer is not None:
737
+ tb_writer.add_scalar(
738
+ "train/learning_rate", cur_lr, params.batch_idx_train
739
+ )
740
+ loss_info.write_summary(
741
+ tb_writer, "train/current_", params.batch_idx_train
742
+ )
743
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
744
+ if params.use_fp16:
745
+ tb_writer.add_scalar(
746
+ "train/grad_scale",
747
+ cur_grad_scale,
748
+ params.batch_idx_train,
749
+ )
750
+
751
+ loss_value = tot_loss["loss"]
752
+ params.train_loss = loss_value
753
+ if params.train_loss < params.best_train_loss:
754
+ params.best_train_epoch = params.cur_epoch
755
+ params.best_train_loss = params.train_loss
756
+
757
+
758
+ def compute_validation_loss(
759
+ params: AttributeDict,
760
+ model: Union[nn.Module, DDP],
761
+ teacher_model: Optional[nn.Module],
762
+ valid_dl: torch.utils.data.DataLoader,
763
+ world_size: int = 1,
764
+ ) -> MetricsTracker:
765
+ """Run the validation process."""
766
+
767
+ model.eval()
768
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
769
+
770
+ # used to summarize the stats over iterations
771
+ tot_loss = MetricsTracker()
772
+
773
+ for batch_idx, batch in enumerate(valid_dl):
774
+ tokens, features, features_lens = prepare_input(
775
+ params=params,
776
+ batch=batch,
777
+ device=device,
778
+ return_tokens=True,
779
+ return_feature=True,
780
+ )
781
+
782
+ loss, loss_info = compute_fbank_loss(
783
+ params=params,
784
+ model=model,
785
+ teacher_model=teacher_model,
786
+ features=features,
787
+ features_lens=features_lens,
788
+ tokens=tokens,
789
+ is_training=False,
790
+ )
791
+ assert loss.requires_grad is False
792
+ tot_loss = tot_loss + loss_info
793
+
794
+ if world_size > 1:
795
+ tot_loss.reduce(loss.device)
796
+
797
+ loss_value = tot_loss["loss"]
798
+ if loss_value < params.best_valid_loss:
799
+ params.best_valid_epoch = params.cur_epoch
800
+ params.best_valid_loss = loss_value
801
+
802
+ return tot_loss
803
+
804
+
805
+ def scan_pessimistic_batches_for_oom(
806
+ model: Union[nn.Module, DDP],
807
+ teacher_model: nn.Module,
808
+ train_dl: torch.utils.data.DataLoader,
809
+ optimizer: torch.optim.Optimizer,
810
+ params: AttributeDict,
811
+ ):
812
+ from lhotse.dataset import find_pessimistic_batches
813
+
814
+ logging.info(
815
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
816
+ )
817
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
818
+
819
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
820
+ for criterion, cuts in batches.items():
821
+ batch = train_dl.dataset[cuts]
822
+ tokens, features, features_lens = prepare_input(
823
+ params=params,
824
+ batch=batch,
825
+ device=device,
826
+ return_tokens=True,
827
+ return_feature=True,
828
+ )
829
+ try:
830
+ with autocast("cuda", enabled=params.use_fp16):
831
+
832
+ loss, loss_info = compute_fbank_loss(
833
+ params=params,
834
+ model=model,
835
+ teacher_model=teacher_model,
836
+ features=features,
837
+ features_lens=features_lens,
838
+ tokens=tokens,
839
+ is_training=True,
840
+ )
841
+ loss.backward()
842
+ optimizer.zero_grad()
843
+ except Exception as e:
844
+ if "CUDA out of memory" in str(e):
845
+ logging.error(
846
+ "Your GPU ran out of memory with the current "
847
+ "max_duration setting. We recommend decreasing "
848
+ "max_duration and trying again.\n"
849
+ f"Failing criterion: {criterion} "
850
+ f"(={crit_values[criterion]}) ..."
851
+ )
852
+ display_and_save_batch(batch, params=params)
853
+ raise
854
+ logging.info(
855
+ f"Maximum memory allocated so far is "
856
+ f"{torch.cuda.max_memory_allocated() // 1000000}MB"
857
+ )
858
+
859
+
860
+ def run(rank, world_size, args):
861
+ """
862
+ Args:
863
+ rank:
864
+ It is a value between 0 and `world_size-1`, which is
865
+ passed automatically by `mp.spawn()` in :func:`main`.
866
+ The node with rank 0 is responsible for saving checkpoint.
867
+ world_size:
868
+ Number of GPUs for DDP training.
869
+ args:
870
+ The return value of get_parser().parse_args()
871
+ """
872
+ params = get_params()
873
+ params.update(vars(args))
874
+ params.valid_interval = params.save_every_n
875
+ # Set epoch to a large number to ignore it.
876
+ if params.num_iters > 0:
877
+ params.num_epochs = 1000000
878
+ with open(params.model_config, "r") as f:
879
+ model_config = json.load(f)
880
+ params.update(model_config["model"])
881
+ params.update(model_config["feature"])
882
+
883
+ fix_random_seed(params.seed)
884
+ if world_size > 1:
885
+ setup_dist(rank, world_size, params.master_port)
886
+
887
+ os.makedirs(f"{params.exp_dir}", exist_ok=True)
888
+ copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
889
+ copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
890
+ setup_logger(f"{params.exp_dir}/log/log-train")
891
+
892
+ if args.tensorboard and rank == 0:
893
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
894
+ else:
895
+ tb_writer = None
896
+
897
+ if torch.cuda.is_available():
898
+ params.device = torch.device("cuda", rank)
899
+ else:
900
+ params.device = torch.device("cpu")
901
+ logging.info(f"Device: {params.device}")
902
+
903
+ if params.tokenizer == "emilia":
904
+ tokenizer = EmiliaTokenizer(token_file=params.token_file)
905
+ elif params.tokenizer == "libritts":
906
+ tokenizer = LibriTTSTokenizer(token_file=params.token_file)
907
+ elif params.tokenizer == "espeak":
908
+ tokenizer = EspeakTokenizer(token_file=params.token_file, lang=params.lang)
909
+ else:
910
+ assert params.tokenizer == "simple"
911
+ tokenizer = SimpleTokenizer(token_file=params.token_file)
912
+
913
+ tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}
914
+ params.update(tokenizer_config)
915
+
916
+ logging.info(params)
917
+
918
+ logging.info("About to create model")
919
+
920
+ assert params.teacher_model is not None
921
+ logging.info(f"Loading pre-trained model from {params.teacher_model}")
922
+ model = ZipVoiceDistill(
923
+ **model_config["model"],
924
+ **tokenizer_config,
925
+ )
926
+ _ = load_checkpoint(
927
+ filename=params.teacher_model,
928
+ model=model,
929
+ strict=(params.distill_stage == "second"),
930
+ )
931
+
932
+ if params.distill_stage == "first":
933
+ teacher_model = ZipVoice(
934
+ **model_config["model"],
935
+ **tokenizer_config,
936
+ )
937
+ _ = load_checkpoint(
938
+ filename=params.teacher_model, model=teacher_model, strict=True
939
+ )
940
+ else:
941
+ teacher_model = copy.deepcopy(model)
942
+
943
+ num_param = sum([p.numel() for p in model.parameters()])
944
+ logging.info(f"Number of parameters : {num_param}")
945
+
946
+ model_avg: Optional[nn.Module] = None
947
+ if rank == 0:
948
+ # model_avg is only used with rank 0
949
+ model_avg = copy.deepcopy(model).to(torch.float64)
950
+ assert params.start_epoch > 0, params.start_epoch
951
+ if params.start_epoch > 1:
952
+ logging.info(f"Resuming from epoch {params.start_epoch}")
953
+ if params.distill_stage == "first":
954
+ checkpoints = resume_checkpoint(
955
+ params=params, model=model, model_avg=model_avg
956
+ )
957
+ else:
958
+ checkpoints = resume_checkpoint(
959
+ params=params,
960
+ model=model,
961
+ model_avg=model_avg,
962
+ model_ema=teacher_model,
963
+ )
964
+
965
+ model = model.to(params.device)
966
+ teacher_model.to(params.device)
967
+ teacher_model.eval()
968
+
969
+ if world_size > 1:
970
+ logging.info("Using DDP")
971
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
972
+
973
+ # only update the fm_decoder
974
+ num_trainable = 0
975
+ for name, p in model.named_parameters():
976
+ if "fm_decoder" in name:
977
+ p.requires_grad = True
978
+ num_trainable += p.numel()
979
+ else:
980
+ p.requires_grad = False
981
+
982
+ logging.info(
983
+ "A total of {} trainable parameters ({:.3f}% of the whole model)".format(
984
+ num_trainable, num_trainable / num_param * 100
985
+ )
986
+ )
987
+
988
+ optimizer = ScaledAdam(
989
+ get_parameter_groups_with_lrs(
990
+ model,
991
+ lr=params.base_lr,
992
+ include_names=True,
993
+ ),
994
+ lr=params.base_lr, # should have no effect
995
+ clipping_scale=2.0,
996
+ )
997
+
998
+ scheduler = FixedLRScheduler(optimizer)
999
+
1000
+ scaler = GradScaler("cuda", enabled=params.use_fp16)
1001
+
1002
+ if params.start_epoch > 1 and checkpoints is not None:
1003
+ # load state_dict for optimizers
1004
+ if "optimizer" in checkpoints:
1005
+ logging.info("Loading optimizer state dict")
1006
+ optimizer.load_state_dict(checkpoints["optimizer"])
1007
+
1008
+ # load state_dict for schedulers
1009
+ if "scheduler" in checkpoints:
1010
+ logging.info("Loading scheduler state dict")
1011
+ scheduler.load_state_dict(checkpoints["scheduler"])
1012
+
1013
+ if "grad_scaler" in checkpoints:
1014
+ logging.info("Loading grad scaler state dict")
1015
+ scaler.load_state_dict(checkpoints["grad_scaler"])
1016
+
1017
+ if params.print_diagnostics:
1018
+ opts = diagnostics.TensorDiagnosticOptions(
1019
+ 512
1020
+ ) # allow 4 megabytes per sub-module
1021
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
1022
+
1023
+ if params.inf_check:
1024
+ register_inf_check_hooks(model)
1025
+
1026
+ def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
1027
+ if c.duration < min_len or c.duration > max_len:
1028
+ return False
1029
+ return True
1030
+
1031
+ _remove_short_and_long_utt = partial(
1032
+ remove_short_and_long_utt, min_len=params.min_len, max_len=params.max_len
1033
+ )
1034
+
1035
+ datamodule = TtsDataModule(args)
1036
+ if params.dataset == "emilia":
1037
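+ # Mux weights roughly proportional to the sizes of the Emilia EN (~46k hours) and ZH (~49k hours) subsets.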
+ train_cuts = CutSet.mux(
1038
+ datamodule.train_emilia_EN_cuts(),
1039
+ datamodule.train_emilia_ZH_cuts(),
1040
+ weights=[46000, 49000],
1041
+ )
1042
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
1043
+ dev_cuts = CutSet.mux(
1044
+ datamodule.dev_emilia_EN_cuts(),
1045
+ datamodule.dev_emilia_ZH_cuts(),
1046
+ weights=[0.5, 0.5],
1047
+ )
1048
+ elif params.dataset == "libritts":
1049
+ train_cuts = datamodule.train_libritts_cuts()
1050
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
1051
+ dev_cuts = datamodule.dev_libritts_cuts()
1052
+ else:
1053
+ assert params.dataset == "custom"
1054
+ train_cuts = datamodule.train_custom_cuts(params.train_manifest)
1055
+ train_cuts = train_cuts.filter(_remove_short_and_long_utt)
1056
+ dev_cuts = datamodule.dev_custom_cuts(params.dev_manifest)
1057
+ # To avoid OOM issues due to too long dev cuts
1058
+ dev_cuts = dev_cuts.filter(_remove_short_and_long_utt)
1059
+
1060
+ _tokenize_text = partial(tokenize_text, tokenizer=tokenizer)
1061
+ train_cuts = train_cuts.map(_tokenize_text)
1062
+ dev_cuts = dev_cuts.map(_tokenize_text)
1063
+
1064
+ train_dl = datamodule.train_dataloaders(train_cuts)
1065
+
1066
+ valid_dl = datamodule.dev_dataloaders(dev_cuts)
1067
+
1068
+ if params.scan_oom:
1069
+ scan_pessimistic_batches_for_oom(
1070
+ model=model,
1071
+ teacher_model=teacher_model,
1072
+ train_dl=train_dl,
1073
+ optimizer=optimizer,
1074
+ params=params,
1075
+ )
1076
+ logging.info("Training started")
1077
+
1078
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
1079
+ logging.info(f"Start epoch {epoch}")
1080
+
1081
+ scheduler.step_epoch(epoch - 1)
1082
+ fix_random_seed(params.seed + epoch - 1)
1083
+ train_dl.sampler.set_epoch(epoch - 1)
1084
+
1085
+ params.cur_epoch = epoch
1086
+
1087
+ if tb_writer is not None:
1088
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
1089
+
1090
+ train_one_epoch(
1091
+ params=params,
1092
+ model=model,
1093
+ model_avg=model_avg,
1094
+ teacher_model=teacher_model,
1095
+ optimizer=optimizer,
1096
+ scheduler=scheduler,
1097
+ train_dl=train_dl,
1098
+ valid_dl=valid_dl,
1099
+ scaler=scaler,
1100
+ tb_writer=tb_writer,
1101
+ world_size=world_size,
1102
+ rank=rank,
1103
+ )
1104
+
1105
+ if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
1106
+ break
1107
+
1108
+ if params.print_diagnostics:
1109
+ diagnostic.print_diagnostics()
1110
+ break
1111
+
1112
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
1113
+ save_checkpoint(
1114
+ filename=filename,
1115
+ params=params,
1116
+ model=model,
1117
+ model_avg=model_avg,
1118
+ model_ema=teacher_model,
1119
+ optimizer=optimizer,
1120
+ scheduler=scheduler,
1121
+ sampler=train_dl.sampler,
1122
+ scaler=scaler,
1123
+ rank=rank,
1124
+ )
1125
+
1126
+ if rank == 0:
1127
+ if params.best_train_epoch == params.cur_epoch:
1128
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
1129
+ copyfile(src=filename, dst=best_train_filename)
1130
+
1131
+ if params.best_valid_epoch == params.cur_epoch:
1132
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
1133
+ copyfile(src=filename, dst=best_valid_filename)
1134
+
1135
+ logging.info("Done!")
1136
+
1137
+ if world_size > 1:
1138
+ torch.distributed.barrier()
1139
+ cleanup_dist()
1140
+
1141
+
1142
+ def main():
1143
+ parser = get_parser()
1144
+ TtsDataModule.add_arguments(parser)
1145
+ args = parser.parse_args()
1146
+ args.exp_dir = Path(args.exp_dir)
1147
+
1148
+ world_size = args.world_size
1149
+ assert world_size >= 1
1150
+ if world_size > 1:
1151
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
1152
+ else:
1153
+ run(rank=0, world_size=1, args=args)
1154
+
1155
+
1156
+ if __name__ == "__main__":
1157
+ torch.set_num_threads(1)
1158
+ torch.set_num_interop_threads(1)
1159
+ main()
zipvoice/dataset/datamodule.py ADDED
@@ -0,0 +1,319 @@
1
+ # Copyright 2021 Piotr Żelasko
2
+ # Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo,
3
+ # Zengwei Yao,
4
+ # Zengrui Jin,
5
+ # Han Zhu,
6
+ # Wei Kang)
7
+ #
8
+ # See ../../../../LICENSE for clarification regarding multiple authors
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ import argparse
24
+ import logging
25
+ from functools import lru_cache
26
+ from pathlib import Path
27
+ from typing import Any, Dict, Optional
28
+
29
+ import torch
30
+ from lhotse import CutSet, load_manifest_lazy
31
+ from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
32
+ from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures
33
+ from lhotse.utils import fix_random_seed
34
+ from torch.utils.data import DataLoader
35
+
36
+ from zipvoice.dataset.dataset import SpeechSynthesisDataset
37
+ from zipvoice.utils.common import str2bool
38
+ from zipvoice.utils.feature import VocosFbank
39
+
40
+
41
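+ # Gives every dataloader worker a distinct but reproducible random seed.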
+ class _SeedWorkers:
42
+ def __init__(self, seed: int):
43
+ self.seed = seed
44
+
45
+ def __call__(self, worker_id: int):
46
+ fix_random_seed(self.seed + worker_id)
47
+
48
+
49
+ SAMPLING_RATE = 24000
50
+
51
+
52
+ class TtsDataModule:
53
+ """
54
+ DataModule for tts experiments.
55
+ It assumes there is always one train and valid dataloader,
56
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
57
+ and test-other).
58
+
59
+ It contains all the common data pipeline modules used in TTS
60
+ experiments, e.g.:
61
+ - dynamic batch size,
62
+ - bucketing samplers,
63
+ - cut concatenation,
64
+ - on-the-fly feature extraction
65
+
66
+ This class should be derived for specific corpora used in TTS tasks.
67
+ """
68
+
69
+ def __init__(self, args: argparse.Namespace):
70
+ self.args = args
71
+
72
+ @classmethod
73
+ def add_arguments(cls, parser: argparse.ArgumentParser):
74
+ group = parser.add_argument_group(
75
+ title="TTS data related options",
76
+ description="These options are used for the preparation of "
77
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
78
+ "effective batch sizes, sampling strategies, applied data "
79
+ "augmentations, etc.",
80
+ )
81
+ group.add_argument(
82
+ "--manifest-dir",
83
+ type=Path,
84
+ default=Path("data/fbank"),
85
+ help="Path to directory with train/valid/test cuts.",
86
+ )
87
+ group.add_argument(
88
+ "--max-duration",
89
+ type=int,
90
+ default=200.0,
91
+ help="Maximum pooled recordings duration (seconds) in a "
92
+ "single batch. You can reduce it if it causes CUDA OOM.",
93
+ )
94
+ group.add_argument(
95
+ "--bucketing-sampler",
96
+ type=str2bool,
97
+ default=True,
98
+ help="When enabled, the batches will come from buckets of "
99
+ "similar duration (saves padding frames).",
100
+ )
101
+ group.add_argument(
102
+ "--num-buckets",
103
+ type=int,
104
+ default=30,
105
+ help="The number of buckets for the DynamicBucketingSampler"
106
+ "(you might want to increase it for larger datasets).",
107
+ )
108
+
109
+ group.add_argument(
110
+ "--on-the-fly-feats",
111
+ type=str2bool,
112
+ default=False,
113
+ help="When enabled, use on-the-fly cut mixing and feature "
114
+ "extraction. Will drop existing precomputed feature manifests "
115
+ "if available.",
116
+ )
117
+ group.add_argument(
118
+ "--shuffle",
119
+ type=str2bool,
120
+ default=True,
121
+ help="When enabled (=default), the examples will be "
122
+ "shuffled for each epoch.",
123
+ )
124
+ group.add_argument(
125
+ "--drop-last",
126
+ type=str2bool,
127
+ default=True,
128
+ help="Whether to drop last batch. Used by sampler.",
129
+ )
130
+ group.add_argument(
131
+ "--return-cuts",
132
+ type=str2bool,
133
+ default=False,
134
+ help="When enabled, each batch will have the "
135
+ "field: batch['cut'] with the cuts that "
136
+ "were used to construct it.",
137
+ )
138
+ group.add_argument(
139
+ "--num-workers",
140
+ type=int,
141
+ default=8,
142
+ help="The number of training dataloader workers that "
143
+ "collect the batches.",
144
+ )
145
+
146
+ group.add_argument(
147
+ "--input-strategy",
148
+ type=str,
149
+ default="PrecomputedFeatures",
150
+ help="AudioSamples or PrecomputedFeatures",
151
+ )
152
+
153
+ def train_dataloaders(
154
+ self,
155
+ cuts_train: CutSet,
156
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
157
+ ) -> DataLoader:
158
+ """
159
+ Args:
160
+ cuts_train:
161
+ CutSet for training.
162
+ sampler_state_dict:
163
+ The state dict for the training sampler.
164
+ """
165
+ logging.info("About to create train dataset")
166
+
167
+ train = SpeechSynthesisDataset(
168
+ return_text=True,
169
+ return_tokens=True,
170
+ return_spk_ids=True,
171
+ feature_input_strategy=OnTheFlyFeatures(VocosFbank())
172
+ if self.args.on_the_fly_feats
173
+ else PrecomputedFeatures(),
174
+ return_cuts=self.args.return_cuts,
175
+ )
176
+
177
+ if self.args.bucketing_sampler:
178
+ logging.info("Using DynamicBucketingSampler.")
179
+ train_sampler = DynamicBucketingSampler(
180
+ cuts_train,
181
+ max_duration=self.args.max_duration,
182
+ shuffle=self.args.shuffle,
183
+ num_buckets=self.args.num_buckets,
184
+ buffer_size=self.args.num_buckets * 2000,
185
+ shuffle_buffer_size=self.args.num_buckets * 5000,
186
+ drop_last=self.args.drop_last,
187
+ )
188
+ else:
189
+ logging.info("Using SimpleCutSampler.")
190
+ train_sampler = SimpleCutSampler(
191
+ cuts_train,
192
+ max_duration=self.args.max_duration,
193
+ shuffle=self.args.shuffle,
194
+ )
195
+ logging.info("About to create train dataloader")
196
+
197
+ if sampler_state_dict is not None:
198
+ logging.info("Loading sampler state dict")
199
+ train_sampler.load_state_dict(sampler_state_dict)
200
+
201
+ # 'seed' is derived from the current random state, which will have
202
+ # previously been set in the main process.
203
+ seed = torch.randint(0, 100000, ()).item()
204
+ worker_init_fn = _SeedWorkers(seed)
205
+
206
+ train_dl = DataLoader(
207
+ train,
208
+ sampler=train_sampler,
209
+ batch_size=None,
210
+ num_workers=self.args.num_workers,
211
+ persistent_workers=False,
212
+ worker_init_fn=worker_init_fn,
213
+ )
214
+
215
+ return train_dl
216
+
217
+ def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
218
+ logging.info("About to create dev dataset")
219
+ validate = SpeechSynthesisDataset(
220
+ return_text=True,
221
+ return_tokens=True,
222
+ return_spk_ids=True,
223
+ feature_input_strategy=OnTheFlyFeatures(VocosFbank())
224
+ if self.args.on_the_fly_feats
225
+ else PrecomputedFeatures(),
226
+ return_cuts=self.args.return_cuts,
227
+ )
228
+ dev_sampler = DynamicBucketingSampler(
229
+ cuts_valid,
230
+ max_duration=self.args.max_duration,
231
+ shuffle=False,
232
+ )
233
+ logging.info("About to create valid dataloader")
234
+ dev_dl = DataLoader(
235
+ validate,
236
+ sampler=dev_sampler,
237
+ batch_size=None,
238
+ num_workers=2,
239
+ persistent_workers=False,
240
+ )
241
+
242
+ return dev_dl
243
+
244
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
245
+ logging.info("About to create test dataset")
246
+ test = SpeechSynthesisDataset(
247
+ return_text=True,
248
+ return_tokens=True,
249
+ return_spk_ids=True,
250
+ feature_input_strategy=OnTheFlyFeatures(VocosFbank())
251
+ if self.args.on_the_fly_feats
252
+ else PrecomputedFeatures(),
253
+ return_cuts=self.args.return_cuts,
254
+ return_audio=True,
255
+ )
256
+ test_sampler = DynamicBucketingSampler(
257
+ cuts,
258
+ max_duration=self.args.max_duration,
259
+ shuffle=False,
260
+ )
261
+ logging.info("About to create test dataloader")
262
+ test_dl = DataLoader(
263
+ test,
264
+ batch_size=None,
265
+ sampler=test_sampler,
266
+ num_workers=self.args.num_workers,
267
+ )
268
+ return test_dl
269
+
270
+ @lru_cache()
271
+ def train_custom_cuts(self, manifest_file) -> CutSet:
272
+ logging.info(f"About to get the custom training cuts {manifest_file}")
273
+ return load_manifest_lazy(manifest_file)
274
+
275
+ @lru_cache()
276
+ def dev_custom_cuts(self, manifest_file) -> CutSet:
277
+ logging.info(f"About to get the custom validation cuts {manifest_file}")
278
+ return load_manifest_lazy(manifest_file)
279
+
280
+ @lru_cache()
281
+ def train_emilia_EN_cuts(self) -> CutSet:
282
+ logging.info("About to get train the EN subset")
283
+ return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_EN.jsonl.gz")
284
+
285
+ @lru_cache()
286
+ def train_emilia_ZH_cuts(self) -> CutSet:
287
+ logging.info("About to get train the ZH subset")
288
+ return load_manifest_lazy(self.args.manifest_dir / "emilia_cuts_ZH.jsonl.gz")
289
+
290
+ @lru_cache()
291
+ def dev_emilia_EN_cuts(self) -> CutSet:
292
+ logging.info("About to get dev the EN subset")
293
+ return load_manifest_lazy(
294
+ self.args.manifest_dir / "emilia_cuts_EN-dev.jsonl.gz"
295
+ )
296
+
297
+ @lru_cache()
298
+ def dev_emilia_ZH_cuts(self) -> CutSet:
299
+ logging.info("About to get dev the ZH subset")
300
+ return load_manifest_lazy(
301
+ self.args.manifest_dir / "emilia_cuts_ZH-dev.jsonl.gz"
302
+ )
303
+
304
+ @lru_cache()
305
+ def train_libritts_cuts(self) -> CutSet:
306
+ logging.info(
307
+ "About to get the shuffled train-clean-100, \
308
+ train-clean-360 and train-other-500 cuts"
309
+ )
310
+ return load_manifest_lazy(
311
+ self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz"
312
+ )
313
+
314
+ @lru_cache()
315
+ def dev_libritts_cuts(self) -> CutSet:
316
+ logging.info("About to get dev-clean cuts")
317
+ return load_manifest_lazy(
318
+ self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
319
+ )
zipvoice/dataset/dataset.py ADDED
@@ -0,0 +1,105 @@
1
+ from typing import Callable, Dict, List, Sequence, Union
2
+
3
+ import torch
4
+ from lhotse import CutSet, validate
5
+ from lhotse.dataset import PrecomputedFeatures
6
+ from lhotse.dataset.collation import collate_audio
7
+ from lhotse.dataset.input_strategies import BatchIO
8
+ from lhotse.utils import ifnone
9
+
10
+
11
+ class SpeechSynthesisDataset(torch.utils.data.Dataset):
12
+ """
13
+ The PyTorch Dataset for the speech synthesis task.
14
+ Each item in this dataset is a dict of:
15
+
16
+ .. code-block::
17
+
18
+ {
19
+ 'audio': (B x NumSamples) float tensor
20
+ 'features': (B x NumFrames x NumFeatures) float tensor
21
+ 'audio_lens': (B, ) int tensor
22
+ 'features_lens': (B, ) int tensor
23
+ 'text': List[str] of len B # when return_text=True
24
+ 'tokens': List[List[str]] # when return_tokens=True
25
+ 'speakers': List[str] of len B # when return_spk_ids=True
26
+ 'cut': List of Cuts # when return_cuts=True
27
+ }
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ cut_transforms: List[Callable[[CutSet], CutSet]] = None,
33
+ feature_input_strategy: BatchIO = PrecomputedFeatures(),
34
+ feature_transforms: Union[Sequence[Callable], Callable] = None,
35
+ return_text: bool = True,
36
+ return_tokens: bool = False,
37
+ return_spk_ids: bool = False,
38
+ return_cuts: bool = False,
39
+ return_audio: bool = False,
40
+ ) -> None:
41
+ super().__init__()
42
+
43
+ self.cut_transforms = ifnone(cut_transforms, [])
44
+ self.feature_input_strategy = feature_input_strategy
45
+
46
+ self.return_text = return_text
47
+ self.return_tokens = return_tokens
48
+ self.return_spk_ids = return_spk_ids
49
+ self.return_cuts = return_cuts
50
+ self.return_audio = return_audio
51
+
52
+ if feature_transforms is None:
53
+ feature_transforms = []
54
+ elif not isinstance(feature_transforms, Sequence):
55
+ feature_transforms = [feature_transforms]
56
+
57
+ assert all(
58
+ isinstance(transform, Callable) for transform in feature_transforms
59
+ ), "Feature transforms must be Callable"
60
+ self.feature_transforms = feature_transforms
61
+
62
+ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
63
+ validate_for_tts(cuts)
64
+
65
+ for transform in self.cut_transforms:
66
+ cuts = transform(cuts)
67
+
68
+ features, features_lens = self.feature_input_strategy(cuts)
69
+
70
+ for transform in self.feature_transforms:
71
+ features = transform(features)
72
+
73
+ batch = {
74
+ "features": features,
75
+ "features_lens": features_lens,
76
+ }
77
+
78
+ if self.return_audio:
79
+ audio, audio_lens = collate_audio(cuts)
80
+ batch["audio"] = audio
81
+ batch["audio_lens"] = audio_lens
82
+
83
+ if self.return_text:
84
+ text = [cut.supervisions[0].text for cut in cuts]
85
+ batch["text"] = text
86
+
87
+ if self.return_tokens:
88
+ tokens = [cut.supervisions[0].tokens for cut in cuts]
89
+ batch["tokens"] = tokens
90
+
91
+ if self.return_spk_ids:
92
+ batch["speakers"] = [cut.supervisions[0].speaker for cut in cuts]
93
+
94
+ if self.return_cuts:
95
+ batch["cut"] = [cut for cut in cuts]
96
+
97
+ return batch
98
+
99
+
100
+ def validate_for_tts(cuts: CutSet) -> None:
101
+ validate(cuts)
102
+ for cut in cuts:
103
+ assert (
104
+ len(cut.supervisions) == 1
105
+ ), "Only the Cuts with single supervision are supported."
zipvoice/eval/evaluate_sim.py ADDED
@@ -0,0 +1,535 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu
3
+ # Wei Kang)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+
20
+ """
21
+ Calculate pairwise Speaker Similarity between two speech directories.
22
+ SV model wavlm_large_finetune.pth is downloaded from
23
+ https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification
24
+ SSL model wavlm_large.pt is downloaded from
25
+ https://huggingface.co/s3prl/converted_ckpts/resolve/main/wavlm_large.pt
26
+ """
27
+ import argparse
28
+ import logging
29
+ import os
30
+
31
+ import librosa
32
+ import numpy as np
33
+ import soundfile as sf
34
+ import torch
35
+ import torch.nn as nn
36
+ import torch.nn.functional as F
37
+ from tqdm import tqdm
38
+
39
+ logging.basicConfig(level=logging.INFO)
40
+
41
+
42
+ def get_parser():
43
+ parser = argparse.ArgumentParser()
44
+
45
+ parser.add_argument(
46
+ "--eval-path", type=str, help="path of the evaluated speech directory"
47
+ )
48
+ parser.add_argument(
49
+ "--test-list",
50
+ type=str,
51
+ help="path of the file list that contains the corresponding "
52
+ "relationship between the prompt and evaluated speech. "
53
+ "The first column is the wav name and the third column is the prompt speech",
54
+ )
55
+ parser.add_argument(
56
+ "--sv-model-path",
57
+ type=str,
58
+ default="model/UniSpeech/wavlm_large_finetune.pth",
59
+ help="path of the wavlm-based ECAPA-TDNN model",
60
+ )
61
+ parser.add_argument(
62
+ "--ssl-model-path",
63
+ type=str,
64
+ default="model/s3prl/wavlm_large.pt",
65
+ help="path of the wavlm SSL model",
66
+ )
67
+ return parser
68
+
69
+
70
+ class SpeakerSimilarity:
71
+ def __init__(
72
+ self,
73
+ sv_model_path="model/UniSpeech/wavlm_large_finetune.pth",
74
+ ssl_model_path="model/s3prl/wavlm_large.pt",
75
+ ):
76
+ """
77
+ Initialize
78
+ """
79
+ self.sample_rate = 16000
80
+ self.channels = 1
81
+ self.device = (
82
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
83
+ )
84
+ logging.info("[Speaker Similarity] Using device: {}".format(self.device))
85
+ self.model = ECAPA_TDNN_WAVLLM(
86
+ feat_dim=1024,
87
+ channels=512,
88
+ emb_dim=256,
89
+ sr=16000,
90
+ ssl_model_path=ssl_model_path,
91
+ )
92
+ state_dict = torch.load(
93
+ sv_model_path, map_location=lambda storage, loc: storage
94
+ )
95
+ self.model.load_state_dict(state_dict["model"], strict=False)
96
+ self.model.to(self.device)
97
+ self.model.eval()
98
+
99
+ def get_embeddings(self, wav_list, dtype="float32"):
100
+ """
101
+ Get embeddings
102
+ """
103
+
104
+ def _load_speech_task(fname, sample_rate):
105
+
106
+ wav_data, sr = sf.read(fname, dtype=dtype)
107
+ if sr != sample_rate:
108
+ wav_data = librosa.resample(
109
+ wav_data, orig_sr=sr, target_sr=self.sample_rate
110
+ )
111
+ wav_data = torch.from_numpy(wav_data)
112
+
113
+ return wav_data
114
+
115
+ embd_lst = []
116
+ for file_path in tqdm(wav_list):
117
+ speech = _load_speech_task(file_path, self.sample_rate)
118
+ speech = speech.to(self.device)
119
+ with torch.no_grad():
120
+ embd = self.model([speech])
121
+ embd_lst.append(embd)
122
+
123
+ return embd_lst
124
+
125
+ def score(
126
+ self,
127
+ eval_path,
128
+ test_list,
129
+ dtype="float32",
130
+ ):
131
+ """
132
+ Computes the Speaker Similarity (SIM-o) between two directories of speech files.
133
+
134
+ Parameters:
135
+ - eval_path (str): Path to the directory containing evaluation speech files.
136
+ - test_list (str): Path to the file containing the corresponding relationship
137
+ between prompt and evaluated speech.
138
+ - dtype (str, optional): Data type for loading speech. Default is "float32".
139
+
140
+ Returns:
141
+ - float: The Speaker Similarity (SIM-o) score between the two directories
142
+ of speech files.
143
+ """
144
+ prompt_wavs = []
145
+ eval_wavs = []
146
+ with open(test_list, "r") as fr:
147
+ lines = fr.readlines()
148
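+ # Each line of the test list is tab-separated: wav_name, prompt_text, prompt_wav, text.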
+ for line in lines:
149
+ wav_name, prompt_text, prompt_wav, text = line.strip().split("\t")
150
+ prompt_wavs.append(prompt_wav)
151
+ eval_wavs.append(os.path.join(eval_path, wav_name + ".wav"))
152
+ embds_prompt = self.get_embeddings(prompt_wavs, dtype=dtype)
153
+
154
+ embds_eval = self.get_embeddings(eval_wavs, dtype=dtype)
155
+
156
+ # Check if embeddings are empty
157
+ if len(embds_prompt) == 0:
158
+ logging.info("[Speaker Similarity] real set dir is empty, exiting...")
159
+ return -1
160
+ if len(embds_eval) == 0:
161
+ logging.info("[Speaker Similarity] eval set dir is empty, exiting...")
162
+ return -1
163
+
164
+ scores = []
165
+ for real_embd, eval_embd in zip(embds_prompt, embds_eval):
166
+ scores.append(
167
+ torch.nn.functional.cosine_similarity(real_embd, eval_embd, dim=-1)
168
+ .detach()
169
+ .cpu()
170
+ .numpy()
171
+ )
172
+
173
+ return np.mean(scores)
174
+
175
+
176
+ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
177
+
178
+ """ Res2Conv1d + BatchNorm1d + ReLU
179
+ """
180
+
181
+
182
+ class Res2Conv1dReluBn(nn.Module):
183
+ """
184
+ in_channels == out_channels == channels
185
+ """
186
+
187
+ def __init__(
188
+ self,
189
+ channels,
190
+ kernel_size=1,
191
+ stride=1,
192
+ padding=0,
193
+ dilation=1,
194
+ bias=True,
195
+ scale=4,
196
+ ):
197
+ super().__init__()
198
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
199
+ self.scale = scale
200
+ self.width = channels // scale
201
+ self.nums = scale if scale == 1 else scale - 1
202
+
203
+ self.convs = []
204
+ self.bns = []
205
+ for i in range(self.nums):
206
+ self.convs.append(
207
+ nn.Conv1d(
208
+ self.width,
209
+ self.width,
210
+ kernel_size,
211
+ stride,
212
+ padding,
213
+ dilation,
214
+ bias=bias,
215
+ )
216
+ )
217
+ self.bns.append(nn.BatchNorm1d(self.width))
218
+ self.convs = nn.ModuleList(self.convs)
219
+ self.bns = nn.ModuleList(self.bns)
220
+
221
+ def forward(self, x):
222
+ out = []
223
+ spx = torch.split(x, self.width, 1)
224
+ for i in range(self.nums):
225
+ if i == 0:
226
+ sp = spx[i]
227
+ else:
228
+ sp = sp + spx[i]
229
+ # Order: conv -> relu -> bn
230
+ sp = self.convs[i](sp)
231
+ sp = self.bns[i](F.relu(sp))
232
+ out.append(sp)
233
+ if self.scale != 1:
234
+ out.append(spx[self.nums])
235
+ out = torch.cat(out, dim=1)
236
+
237
+ return out
238
+
239
+
240
+ """ Conv1d + BatchNorm1d + ReLU
241
+ """
242
+
243
+
244
+ class Conv1dReluBn(nn.Module):
245
+ def __init__(
246
+ self,
247
+ in_channels,
248
+ out_channels,
249
+ kernel_size=1,
250
+ stride=1,
251
+ padding=0,
252
+ dilation=1,
253
+ bias=True,
254
+ ):
255
+ super().__init__()
256
+ self.conv = nn.Conv1d(
257
+ in_channels,
258
+ out_channels,
259
+ kernel_size,
260
+ stride,
261
+ padding,
262
+ dilation,
263
+ bias=bias,
264
+ )
265
+ self.bn = nn.BatchNorm1d(out_channels)
266
+
267
+ def forward(self, x):
268
+ return self.bn(F.relu(self.conv(x)))
269
+
270
+
271
+ """ The SE connection of 1D case.
272
+ """
273
+
274
+
275
+ class SE_Connect(nn.Module):
276
+ def __init__(self, channels, se_bottleneck_dim=128):
277
+ super().__init__()
278
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
279
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
280
+
281
+ def forward(self, x):
282
+ out = x.mean(dim=2)
283
+ out = F.relu(self.linear1(out))
284
+ out = torch.sigmoid(self.linear2(out))
285
+ out = x * out.unsqueeze(2)
286
+
287
+ return out
288
+
289
+
290
+ """ SE-Res2Block of the ECAPA-TDNN architecture.
291
+ """
292
+
293
+
294
+ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
295
+ # return nn.Sequential(
296
+ # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
297
+ # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
298
+ # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
299
+ # SE_Connect(channels)
300
+ # )
301
+
302
+
303
+ class SE_Res2Block(nn.Module):
304
+ def __init__(
305
+ self,
306
+ in_channels,
307
+ out_channels,
308
+ kernel_size,
309
+ stride,
310
+ padding,
311
+ dilation,
312
+ scale,
313
+ se_bottleneck_dim,
314
+ ):
315
+ super().__init__()
316
+ self.Conv1dReluBn1 = Conv1dReluBn(
317
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
318
+ )
319
+ self.Res2Conv1dReluBn = Res2Conv1dReluBn(
320
+ out_channels, kernel_size, stride, padding, dilation, scale=scale
321
+ )
322
+ self.Conv1dReluBn2 = Conv1dReluBn(
323
+ out_channels, out_channels, kernel_size=1, stride=1, padding=0
324
+ )
325
+ self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
326
+
327
+ self.shortcut = None
328
+ if in_channels != out_channels:
329
+ self.shortcut = nn.Conv1d(
330
+ in_channels=in_channels,
331
+ out_channels=out_channels,
332
+ kernel_size=1,
333
+ )
334
+
335
+ def forward(self, x):
336
+ residual = x
337
+ if self.shortcut:
338
+ residual = self.shortcut(x)
339
+
340
+ x = self.Conv1dReluBn1(x)
341
+ x = self.Res2Conv1dReluBn(x)
342
+ x = self.Conv1dReluBn2(x)
343
+ x = self.SE_Connect(x)
344
+
345
+ return x + residual
346
+
347
+
348
+ """ Attentive weighted mean and standard deviation pooling.
349
+ """
350
+
351
+
352
+ class AttentiveStatsPool(nn.Module):
353
+ def __init__(self, in_dim, attention_channels=128, global_context_att=False):
354
+ super().__init__()
355
+ self.global_context_att = global_context_att
356
+
357
+ # Use Conv1d with stride == 1 rather than Linear,
358
+ # then we don't need to transpose inputs.
359
+ if global_context_att:
360
+ self.linear1 = nn.Conv1d(
361
+ in_dim * 3, attention_channels, kernel_size=1
362
+ ) # equals W and b in the paper
363
+ else:
364
+ self.linear1 = nn.Conv1d(
365
+ in_dim, attention_channels, kernel_size=1
366
+ ) # equals W and b in the paper
367
+ self.linear2 = nn.Conv1d(
368
+ attention_channels, in_dim, kernel_size=1
369
+ ) # equals V and k in the paper
370
+
371
+ def forward(self, x):
372
+
373
+ if self.global_context_att:
374
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
375
+ context_std = torch.sqrt(
376
+ torch.var(x, dim=-1, keepdim=True) + 1e-10
377
+ ).expand_as(x)
378
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
379
+ else:
380
+ x_in = x
381
+
382
+ # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
383
+ alpha = torch.tanh(self.linear1(x_in))
384
+ # alpha = F.relu(self.linear1(x_in))
385
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
386
+ mean = torch.sum(alpha * x, dim=2)
387
+ residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
388
+ std = torch.sqrt(residuals.clamp(min=1e-9))
389
+ return torch.cat([mean, std], dim=1)
390
+
391
+
392
+ class ECAPA_TDNN_WAVLLM(nn.Module):
393
+ def __init__(
394
+ self,
395
+ feat_dim=80,
396
+ channels=512,
397
+ emb_dim=192,
398
+ global_context_att=False,
399
+ sr=16000,
400
+ ssl_model_path=None,
401
+ ):
402
+ super().__init__()
403
+ self.sr = sr
404
+
405
+ if ssl_model_path is None:
406
+ self.feature_extract = torch.hub.load("s3prl/s3prl", "wavlm_large")
407
+ else:
408
+ self.feature_extract = torch.hub.load(
409
+ os.path.dirname(ssl_model_path),
410
+ "wavlm_local",
411
+ source="local",
412
+ ckpt=ssl_model_path,
413
+ )
414
+
415
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
416
+ self.feature_extract.model.encoder.layers[23].self_attn,
417
+ "fp32_attention",
418
+ ):
419
+ self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = (
420
+ False
421
+ )
422
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
423
+ self.feature_extract.model.encoder.layers[11].self_attn,
424
+ "fp32_attention",
425
+ ):
426
+ self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = (
427
+ False
428
+ )
429
+
430
+ self.feat_num = self.get_feat_num()
431
+ self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
432
+
433
+ self.instance_norm = nn.InstanceNorm1d(feat_dim)
434
+ # self.channels = [channels] * 4 + [channels * 3]
435
+ self.channels = [channels] * 4 + [1536]
436
+
437
+ self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
438
+ self.layer2 = SE_Res2Block(
439
+ self.channels[0],
440
+ self.channels[1],
441
+ kernel_size=3,
442
+ stride=1,
443
+ padding=2,
444
+ dilation=2,
445
+ scale=8,
446
+ se_bottleneck_dim=128,
447
+ )
448
+ self.layer3 = SE_Res2Block(
449
+ self.channels[1],
450
+ self.channels[2],
451
+ kernel_size=3,
452
+ stride=1,
453
+ padding=3,
454
+ dilation=3,
455
+ scale=8,
456
+ se_bottleneck_dim=128,
457
+ )
458
+ self.layer4 = SE_Res2Block(
459
+ self.channels[2],
460
+ self.channels[3],
461
+ kernel_size=3,
462
+ stride=1,
463
+ padding=4,
464
+ dilation=4,
465
+ scale=8,
466
+ se_bottleneck_dim=128,
467
+ )
468
+
469
+ # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
470
+ cat_channels = channels * 3
471
+ self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
472
+ self.pooling = AttentiveStatsPool(
473
+ self.channels[-1],
474
+ attention_channels=128,
475
+ global_context_att=global_context_att,
476
+ )
477
+ self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
478
+ self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
479
+
480
+ def get_feat_num(self):
481
+ self.feature_extract.eval()
482
+ wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
483
+ with torch.no_grad():
484
+ features = self.feature_extract(wav)
485
+ select_feature = features["hidden_states"]
486
+ if isinstance(select_feature, (list, tuple)):
487
+ return len(select_feature)
488
+ else:
489
+ return 1
490
+
491
+ def get_feat(self, x):
492
+ with torch.no_grad():
493
+ x = self.feature_extract([sample for sample in x])
494
+
495
+ x = x["hidden_states"]
496
+ if isinstance(x, (list, tuple)):
497
+ x = torch.stack(x, dim=0)
498
+ else:
499
+ x = x.unsqueeze(0)
500
+ norm_weights = (
501
+ F.softmax(self.feature_weight, dim=-1)
502
+ .unsqueeze(-1)
503
+ .unsqueeze(-1)
504
+ .unsqueeze(-1)
505
+ )
506
+ x = (norm_weights * x).sum(dim=0)
507
+ x = torch.transpose(x, 1, 2) + 1e-6
508
+
509
+ x = self.instance_norm(x)
510
+ return x
511
+
512
+ def forward(self, x):
513
+ x = self.get_feat(x)
514
+
515
+ out1 = self.layer1(x)
516
+ out2 = self.layer2(out1)
517
+ out3 = self.layer3(out2)
518
+ out4 = self.layer4(out3)
519
+
520
+ out = torch.cat([out2, out3, out4], dim=1)
521
+ out = F.relu(self.conv(out))
522
+ out = self.bn(self.pooling(out))
523
+ out = self.linear(out)
524
+
525
+ return out
526
+
527
+
528
+ if __name__ == "__main__":
529
+ parser = get_parser()
530
+ args = parser.parse_args()
531
+ SIM = SpeakerSimilarity(
532
+ sv_model_path=args.sv_model_path, ssl_model_path=args.ssl_model_path
533
+ )
534
+ score = SIM.score(args.eval_path, args.test_list)
535
+ logging.info(f"SIM-o score: {score:.3f}")
zipvoice/eval/evaluate_utmos.py ADDED
@@ -0,0 +1,314 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu
3
+ # Wei Kang)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+
20
+ """
21
+ Calculate UTMOS score with automatic Mean Opinion Score (MOS) prediction system
22
+ adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo
23
+
24
+ # Download model checkpoints
25
+ wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt -P model/huggingface/utmos/utmos.pt
26
+ wget https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt -P model/huggingface/utmos/wav2vec_small.pt
27
+ """
28
+
29
+ import argparse
30
+ import logging
31
+ import os
32
+
33
+ import fairseq
34
+ import librosa
35
+ import numpy as np
36
+ import pytorch_lightning as pl
37
+ import soundfile as sf
38
+ import torch
39
+ import torch.nn as nn
40
+ from tqdm import tqdm
41
+
42
+ logging.basicConfig(level=logging.INFO)
43
+
44
+
45
+ def get_parser():
46
+ parser = argparse.ArgumentParser()
47
+
48
+ parser.add_argument(
49
+ "--wav-path", type=str, help="path of the evaluated speech directory"
50
+ )
51
+ parser.add_argument(
52
+ "--utmos-model-path",
53
+ type=str,
54
+ default="model/huggingface/utmos/utmos.pt",
55
+ help="path of the UTMOS model",
56
+ )
57
+ parser.add_argument(
58
+ "--ssl-model-path",
59
+ type=str,
60
+ default="model/huggingface/utmos/wav2vec_small.pt",
61
+ help="path of the wav2vec SSL model",
62
+ )
63
+ return parser
64
+
65
+
66
+ class UTMOSScore:
67
+ """Predicting score for each audio clip."""
68
+
69
+ def __init__(self, utmos_model_path, ssl_model_path):
70
+ self.sample_rate = 16000
71
+ self.device = (
72
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
73
+ )
74
+ self.model = (
75
+ BaselineLightningModule.load_from_checkpoint(
76
+ utmos_model_path, ssl_model_path=ssl_model_path
77
+ )
78
+ .eval()
79
+ .to(self.device)
80
+ )
81
+
82
+ def score(self, wavs: torch.Tensor) -> torch.Tensor:
83
+ """
84
+ Args:
85
+ wavs: waveforms to be evaluated. When wavs.ndim is 1 or 2,
86
+ the model processes the input as a single audio clip. The model
87
+ performs batch processing when wavs.ndim is 3.
88
+ """
89
+ if len(wavs.shape) == 1:
90
+ out_wavs = wavs.unsqueeze(0).unsqueeze(0)
91
+ elif len(wavs.shape) == 2:
92
+ out_wavs = wavs.unsqueeze(0)
93
+ elif len(wavs.shape) == 3:
94
+ out_wavs = wavs
95
+ else:
96
+ raise ValueError("Dimension of input tensor needs to be <= 3.")
97
+ bs = out_wavs.shape[0]
98
+ batch = {
99
+ "wav": out_wavs,
100
+ "domains": torch.zeros(bs, dtype=torch.int).to(self.device),
101
+ "judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288,
102
+ }
103
+ with torch.no_grad():
104
+ output = self.model(batch)
105
+
106
+ return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3
107
+
108
+ def score_dir(self, dir, dtype="float32"):
109
+ def _load_speech_task(fname, sample_rate):
110
+
111
+ wav_data, sr = sf.read(fname, dtype=dtype)
112
+ if sr != sample_rate:
113
+ wav_data = librosa.resample(
114
+ wav_data, orig_sr=sr, target_sr=self.sample_rate
115
+ )
116
+ wav_data = torch.from_numpy(wav_data)
117
+
118
+ return wav_data
119
+
120
+ score_lst = []
121
+ for fname in tqdm(os.listdir(dir)):
122
+ speech = _load_speech_task(os.path.join(dir, fname), self.sample_rate)
123
+ speech = speech.to(self.device)
124
+ with torch.no_grad():
125
+ score = self.score(speech)
126
+ score_lst.append(score.item())
127
+ return np.mean(score_lst)
128
+
129
+
130
+ def load_ssl_model(ckpt_path="wav2vec_small.pt"):
131
+ SSL_OUT_DIM = 768
132
+ model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
133
+ [ckpt_path]
134
+ )
135
+ ssl_model = model[0]
136
+ ssl_model.remove_pretraining_modules()
137
+ return SSL_model(ssl_model, SSL_OUT_DIM)
138
+
139
+
140
+ class BaselineLightningModule(pl.LightningModule):
141
+ def __init__(self, ssl_model_path):
142
+ super().__init__()
143
+ self.construct_model(ssl_model_path)
144
+ self.save_hyperparameters()
145
+
146
+ def construct_model(self, ssl_model_path):
147
+ self.feature_extractors = nn.ModuleList(
148
+ [
149
+ load_ssl_model(ckpt_path=ssl_model_path),
150
+ DomainEmbedding(3, 128),
151
+ ]
152
+ )
153
+ output_dim = sum(
154
+ [
155
+ feature_extractor.get_output_dim()
156
+ for feature_extractor in self.feature_extractors
157
+ ]
158
+ )
159
+ output_layers = [
160
+ LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)
161
+ ]
162
+ output_dim = output_layers[-1].get_output_dim()
163
+ output_layers.append(
164
+ Projection(
165
+ hidden_dim=2048,
166
+ activation=torch.nn.ReLU(),
167
+ range_clipping=False,
168
+ input_dim=output_dim,
169
+ )
170
+ )
171
+
172
+ self.output_layers = nn.ModuleList(output_layers)
173
+
174
+ def forward(self, inputs):
175
+ outputs = {}
176
+ for feature_extractor in self.feature_extractors:
177
+ outputs.update(feature_extractor(inputs))
178
+ x = outputs
179
+ for output_layer in self.output_layers:
180
+ x = output_layer(x, inputs)
181
+ return x
182
+
183
+
184
+ class SSL_model(nn.Module):
185
+ def __init__(self, ssl_model, ssl_out_dim) -> None:
186
+ super(SSL_model, self).__init__()
187
+ self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
188
+
189
+ def forward(self, batch):
190
+ wav = batch["wav"]
191
+ wav = wav.squeeze(1) # [batches, wav_len]
192
+ res = self.ssl_model(wav, mask=False, features_only=True)
193
+ x = res["x"]
194
+ return {"ssl-feature": x}
195
+
196
+ def get_output_dim(self):
197
+ return self.ssl_out_dim
198
+
199
+
200
+ class DomainEmbedding(nn.Module):
201
+ def __init__(self, n_domains, domain_dim) -> None:
202
+ super().__init__()
203
+ self.embedding = nn.Embedding(n_domains, domain_dim)
204
+ self.output_dim = domain_dim
205
+
206
+ def forward(self, batch):
207
+ return {"domain-feature": self.embedding(batch["domains"])}
208
+
209
+ def get_output_dim(self):
210
+ return self.output_dim
211
+
212
+
213
+ class LDConditioner(nn.Module):
214
+ """
215
+ Conditions ssl output by listener embedding
216
+ """
217
+
218
+ def __init__(self, input_dim, judge_dim, num_judges=None):
219
+ super().__init__()
220
+ self.input_dim = input_dim
221
+ self.judge_dim = judge_dim
222
+ self.num_judges = num_judges
223
+ assert num_judges is not None
224
+ self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
225
+ # concat [self.output_layer, phoneme features]
226
+
227
+ self.decoder_rnn = nn.LSTM(
228
+ input_size=self.input_dim + self.judge_dim,
229
+ hidden_size=512,
230
+ num_layers=1,
231
+ batch_first=True,
232
+ bidirectional=True,
233
+ ) # linear?
234
+ self.out_dim = self.decoder_rnn.hidden_size * 2
235
+
236
+ def get_output_dim(self):
237
+ return self.out_dim
238
+
239
+ def forward(self, x, batch):
240
+ judge_ids = batch["judge_id"]
241
+ if "phoneme-feature" in x.keys():
242
+ concatenated_feature = torch.cat(
243
+ (
244
+ x["ssl-feature"],
245
+ x["phoneme-feature"]
246
+ .unsqueeze(1)
247
+ .expand(-1, x["ssl-feature"].size(1), -1),
248
+ ),
249
+ dim=2,
250
+ )
251
+ else:
252
+ concatenated_feature = x["ssl-feature"]
253
+ if "domain-feature" in x.keys():
254
+ concatenated_feature = torch.cat(
255
+ (
256
+ concatenated_feature,
257
+ x["domain-feature"]
258
+ .unsqueeze(1)
259
+ .expand(-1, concatenated_feature.size(1), -1),
260
+ ),
261
+ dim=2,
262
+ )
263
+ if judge_ids is not None:
264
+ concatenated_feature = torch.cat(
265
+ (
266
+ concatenated_feature,
267
+ self.judge_embedding(judge_ids)
268
+ .unsqueeze(1)
269
+ .expand(-1, concatenated_feature.size(1), -1),
270
+ ),
271
+ dim=2,
272
+ )
273
+ decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
274
+ return decoder_output
275
+
276
+
277
+ class Projection(nn.Module):
278
+ def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
279
+ super(Projection, self).__init__()
280
+ self.range_clipping = range_clipping
281
+ output_dim = 1
282
+ if range_clipping:
283
+ self.proj = nn.Tanh()
284
+
285
+ self.net = nn.Sequential(
286
+ nn.Linear(input_dim, hidden_dim),
287
+ activation,
288
+ nn.Dropout(0.3),
289
+ nn.Linear(hidden_dim, output_dim),
290
+ )
291
+ self.output_dim = output_dim
292
+
293
+ def forward(self, x, batch):
294
+ output = self.net(x)
295
+
296
+ # range clipping
297
+ if self.range_clipping:
298
+ return self.proj(output) * 2.0 + 3
299
+ else:
300
+ return output
301
+
302
+ def get_output_dim(self):
303
+ return self.output_dim
304
+
305
+
306
+ if __name__ == "__main__":
307
+ parser = get_parser()
308
+ args = parser.parse_args()
309
+ UTMOS = UTMOSScore(
310
+ utmos_model_path=args.utmos_model_path,
311
+ ssl_model_path=args.ssl_model_path,
312
+ )
313
+ score = UTMOS.score_dir(args.wav_path)
314
+ logging.info(f"UTMOS score: {score:.2f}")
zipvoice/eval/evaluate_wer_hubert.py ADDED
@@ -0,0 +1,192 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu,
3
+ # Wei Kang)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+
20
+ """
21
+ Calculate WER with Hubert models.
22
+ """
23
+ import argparse
24
+ import os
25
+ import re
26
+ from pathlib import Path
27
+
28
+ import librosa
29
+ import numpy as np
30
+ import soundfile as sf
31
+ import torch
32
+ from jiwer import compute_measures
33
+ from tqdm import tqdm
34
+ from transformers import pipeline
35
+
36
+
37
+ def get_parser():
38
+ parser = argparse.ArgumentParser()
39
+
40
+ parser.add_argument("--wav-path", type=str, help="path of the speech directory")
41
+ parser.add_argument(
42
+ "--decode-path",
43
+ type=str,
44
+ default=None,
45
+ help="path of the output file of WER information",
46
+ )
47
+ parser.add_argument(
48
+ "--model-path",
49
+ type=str,
50
+ default=None,
51
+ help="path of the local hubert model, e.g., "
52
+ "model/huggingface/hubert-large-ls960-ft",
53
+ )
54
+ parser.add_argument(
55
+ "--test-list",
56
+ type=str,
57
+ default="test.tsv",
58
+ help="path of the transcript tsv file, where the first column "
59
+ "is the wav name and the last column is the transcript",
60
+ )
61
+ parser.add_argument(
62
+ "--batch-size", type=int, default=16, help="decoding batch size"
63
+ )
64
+ return parser
65
+
66
+
67
+ def post_process(text: str):
68
+ text = text.replace("‘", "'")
69
+ text = text.replace("’", "'")
70
+ text = re.sub(r"[^a-zA-Z0-9']", " ", text.lower())
71
+ text = re.sub(r"\s+", " ", text)
72
+ text = text.strip()
73
+ return text
74
+
75
+
76
+ def process_one(hypo, truth):
77
+ truth = post_process(truth)
78
+ hypo = post_process(hypo)
79
+
80
+ measures = compute_measures(truth, hypo)
81
+ word_num = len(truth.split(" "))
82
+ wer = measures["wer"]
83
+ subs = measures["substitutions"]
84
+ dele = measures["deletions"]
85
+ inse = measures["insertions"]
86
+ return (truth, hypo, wer, subs, dele, inse, word_num)
87
+
88
+
89
+ class SpeechEvalDataset(torch.utils.data.Dataset):
90
+ def __init__(self, wav_path: str, test_list: str):
91
+ super().__init__()
92
+ self.wav_name = []
93
+ self.wav_paths = []
94
+ self.transcripts = []
95
+ with Path(test_list).open("r", encoding="utf8") as f:
96
+ meta = [item.split("\t") for item in f.read().rstrip().split("\n")]
97
+ for item in meta:
98
+ self.wav_name.append(item[0])
99
+ self.wav_paths.append(Path(wav_path, item[0] + ".wav"))
100
+ self.transcripts.append(item[-1])
101
+
102
+ def __len__(self):
103
+ return len(self.wav_paths)
104
+
105
+ def __getitem__(self, index: int):
106
+ wav, sampling_rate = sf.read(self.wav_paths[index])
107
+ item = {
108
+ "array": librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000),
109
+ "sampling_rate": 16000,
110
+ "reference": self.transcripts[index],
111
+ "wav_name": self.wav_name[index],
112
+ }
113
+ return item
114
+
115
+
116
+ def main(test_list, wav_path, model_path, decode_path, batch_size, device):
117
+
118
+ if model_path is not None:
119
+ pipe = pipeline(
120
+ "automatic-speech-recognition",
121
+ model=model_path,
122
+ device=device,
123
+ tokenizer=model_path,
124
+ )
125
+ else:
126
+ pipe = pipeline(
127
+ "automatic-speech-recognition",
128
+ model="facebook/hubert-large-ls960-ft",
129
+ device=device,
130
+ )
131
+
132
+ dataset = SpeechEvalDataset(wav_path, test_list)
133
+
134
+ bar = tqdm(
135
+ pipe(
136
+ dataset,
137
+ generate_kwargs={"language": "english", "task": "transcribe"},
138
+ batch_size=batch_size,
139
+ ),
140
+ total=len(dataset),
141
+ )
142
+
143
+ wers = []
144
+ inses = []
145
+ deles = []
146
+ subses = []
147
+ word_nums = 0
148
+ if decode_path:
149
+ decode_dir = os.path.dirname(decode_path)
150
+ if not os.path.exists(decode_dir):
151
+ os.makedirs(decode_dir)
152
+ fout = open(decode_path, "w")
153
+ for out in bar:
154
+ wav_name = out["wav_name"][0]
155
+ transcription = post_process(out["text"].strip())
156
+ text_ref = post_process(out["reference"][0].strip())
157
+ truth, hypo, wer, subs, dele, inse, word_num = process_one(
158
+ transcription, text_ref
159
+ )
160
+ if decode_path:
161
+ fout.write(f"{wav_name}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
162
+ wers.append(float(wer))
163
+ inses.append(float(inse))
164
+ deles.append(float(dele))
165
+ subses.append(float(subs))
166
+ word_nums += word_num
167
+
168
+ wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
169
+ subs = round(np.mean(subses) * 100, 3)
170
+ dele = round(np.mean(deles) * 100, 3)
171
+ inse = round(np.mean(inses) * 100, 3)
172
+ print(f"WER: {wer}%\n")
173
+ if decode_path:
174
+ fout.write(f"WER: {wer}%\n")
175
+ fout.flush()
176
+
177
+
178
+ if __name__ == "__main__":
179
+ parser = get_parser()
180
+ args = parser.parse_args()
181
+ if torch.cuda.is_available():
182
+ device = torch.device("cuda", 0)
183
+ else:
184
+ device = torch.device("cpu")
185
+ main(
186
+ args.test_list,
187
+ args.wav_path,
188
+ args.model_path,
189
+ args.decode_path,
190
+ args.batch_size,
191
+ device,
192
+ )
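The Hubert WER script can also be called as a function; a hedged sketch with placeholder paths:

# Hypothetical call into main() of the Hubert WER evaluation; arguments mirror get_parser().
import torch
from zipvoice.eval.evaluate_wer_hubert import main

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
main(
    test_list="data/test.tsv",             # first column: wav name, last column: transcript
    wav_path="results/generated_wavs",     # directory of synthesized wavs
    model_path=None,                       # None falls back to facebook/hubert-large-ls960-ft
    decode_path="results/wer_hubert.txt",  # optional per-utterance WER details
    batch_size=16,
    device=device,
)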
zipvoice/eval/evaluate_wer_seedtts.py ADDED
@@ -0,0 +1,200 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu
3
+ # Wei Kang)
4
+ #
5
+ # See ../../../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+
20
+ """
21
+ Calculate WER with Whisper-large-v3 or Paraformer models,
22
+ following Seed-TTS https://github.com/BytedanceSpeech/seed-tts-eval
23
+ """
24
+
25
+ import argparse
26
+ import os
27
+ import string
28
+
29
+ import numpy as np
30
+ import scipy
31
+ import soundfile as sf
32
+ import torch
33
+ import zhconv
34
+ from funasr import AutoModel
35
+ from jiwer import compute_measures
36
+ from tqdm import tqdm
37
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
38
+ from zhon.hanzi import punctuation
39
+
40
+
41
+ def get_parser():
42
+ parser = argparse.ArgumentParser()
43
+
44
+ parser.add_argument("--wav-path", type=str, help="path of the speech directory")
45
+ parser.add_argument(
46
+ "--decode-path",
47
+ type=str,
48
+ default=None,
49
+ help="path of the output file of WER information",
50
+ )
51
+ parser.add_argument(
52
+ "--model-path",
53
+ type=str,
54
+ default=None,
55
+ help="path of the local whisper and paraformer model, "
56
+ "e.g., whisper: model/huggingface/whisper-large-v3/, "
57
+ "paraformer: model/huggingface/paraformer-zh/",
58
+ )
59
+ parser.add_argument(
60
+ "--test-list",
61
+ type=str,
62
+ default="test.tsv",
63
+ help="path of the transcript tsv file, where the first column "
64
+ "is the wav name and the last column is the transcript",
65
+ )
66
+ parser.add_argument("--lang", type=str, help="decoded language, zh or en")
67
+ return parser
68
+
69
+
70
+ def load_en_model(model_path):
71
+ if model_path is None:
72
+ model_path = "openai/whisper-large-v3"
73
+ processor = WhisperProcessor.from_pretrained(model_path)
74
+ model = WhisperForConditionalGeneration.from_pretrained(model_path)
75
+ return processor, model
76
+
77
+
78
+ def load_zh_model(model_path):
79
+ if model_path is None:
80
+ model_path = "paraformer-zh"
81
+ model = AutoModel(model=model_path)
82
+ return model
83
+
84
+
85
+ def process_one(hypo, truth, lang):
86
+ punctuation_all = punctuation + string.punctuation
87
+ for x in punctuation_all:
88
+ if x == "'":
89
+ continue
90
+ truth = truth.replace(x, "")
91
+ hypo = hypo.replace(x, "")
92
+
93
+ truth = truth.replace(" ", " ")
94
+ hypo = hypo.replace(" ", " ")
95
+
96
+ if lang == "zh":
97
+ truth = " ".join([x for x in truth])
98
+ hypo = " ".join([x for x in hypo])
99
+ elif lang == "en":
100
+ truth = truth.lower()
101
+ hypo = hypo.lower()
102
+ else:
103
+ raise NotImplementedError
104
+
105
+ measures = compute_measures(truth, hypo)
106
+ word_num = len(truth.split(" "))
107
+ wer = measures["wer"]
108
+ subs = measures["substitutions"]
109
+ dele = measures["deletions"]
110
+ inse = measures["insertions"]
111
+ return (truth, hypo, wer, subs, dele, inse, word_num)
112
+
113
+
114
+ def main(test_list, wav_path, model_path, decode_path, lang, device):
115
+ if lang == "en":
116
+ processor, model = load_en_model(model_path)
117
+ model.to(device)
118
+ elif lang == "zh":
119
+ model = load_zh_model(model_path)
120
+ params = []
121
+ for line in open(test_list).readlines():
122
+ line = line.strip()
123
+ items = line.split("\t")
124
+ wav_name, text_ref = items[0], items[-1]
125
+ file_path = os.path.join(wav_path, wav_name + ".wav")
126
+ assert os.path.exists(file_path), f"{file_path}"
127
+
128
+ params.append((file_path, text_ref))
129
+ wers = []
130
+ inses = []
131
+ deles = []
132
+ subses = []
133
+ word_nums = 0
134
+ if decode_path:
135
+ decode_dir = os.path.dirname(decode_path)
136
+ if not os.path.exists(decode_dir):
137
+ os.makedirs(decode_dir)
138
+ fout = open(decode_path, "w")
139
+ for wav_path, text_ref in tqdm(params):
140
+ if lang == "en":
141
+ wav, sr = sf.read(wav_path)
142
+ if sr != 16000:
143
+ wav = scipy.signal.resample(wav, int(len(wav) * 16000 / sr))
144
+ input_features = processor(
145
+ wav, sampling_rate=16000, return_tensors="pt"
146
+ ).input_features
147
+ input_features = input_features.to(device)
148
+ forced_decoder_ids = processor.get_decoder_prompt_ids(
149
+ language="english", task="transcribe"
150
+ )
151
+ predicted_ids = model.generate(
152
+ input_features, forced_decoder_ids=forced_decoder_ids
153
+ )
154
+ transcription = processor.batch_decode(
155
+ predicted_ids, skip_special_tokens=True
156
+ )[0]
157
+ elif lang == "zh":
158
+ res = model.generate(input=wav_path, batch_size_s=300, disable_pbar=True)
159
+ transcription = res[0]["text"]
160
+ transcription = zhconv.convert(transcription, "zh-cn")
161
+
162
+ truth, hypo, wer, subs, dele, inse, word_num = process_one(
163
+ transcription, text_ref, lang
164
+ )
165
+ if decode_path:
166
+ fout.write(f"{wav_path}\t{wer}\t{truth}\t{hypo}\t{inse}\t{dele}\t{subs}\n")
167
+ wers.append(float(wer))
168
+ inses.append(float(inse))
169
+ deles.append(float(dele))
170
+ subses.append(float(subs))
171
+ word_nums += word_num
172
+
173
+ wer_avg = round(np.mean(wers) * 100, 3)
174
+ wer = round((np.sum(subses) + np.sum(deles) + np.sum(inses)) / word_nums * 100, 3)
175
+ subs = round(np.mean(subses) * 100, 3)
176
+ dele = round(np.mean(deles) * 100, 3)
177
+ inse = round(np.mean(inses) * 100, 3)
178
+ print(f"Seed-TTS WER: {wer_avg}%\n")
179
+ print(f"WER: {wer}%\n")
180
+ if decode_path:
181
+ fout.write(f"SeedTTS WER: {wer_avg}%\n")
182
+ fout.write(f"WER: {wer}%\n")
183
+ fout.flush()
184
+
185
+
186
+ if __name__ == "__main__":
187
+ parser = get_parser()
188
+ args = parser.parse_args()
189
+ if torch.cuda.is_available():
190
+ device = torch.device("cuda", 0)
191
+ else:
192
+ device = torch.device("cpu")
193
+ main(
194
+ args.test_list,
195
+ args.wav_path,
196
+ args.model_path,
197
+ args.decode_path,
198
+ args.lang,
199
+ device,
200
+ )
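A comparable sketch for the Seed-TTS-style WER evaluation; lang selects the Whisper (en) or Paraformer (zh) branch, and all paths are placeholders.

# Hypothetical call into main() of the Seed-TTS WER evaluation.
import torch
from zipvoice.eval.evaluate_wer_seedtts import main

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
main(
    test_list="data/test.tsv",
    wav_path="results/generated_wavs",
    model_path=None,                        # None => openai/whisper-large-v3 (en) or paraformer-zh (zh)
    decode_path="results/wer_seedtts.txt",
    lang="en",
    device=device,
)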
zipvoice/models/modules/scaling.py ADDED
@@ -0,0 +1,1563 @@
1
+ # Copyright 2022-2025 Xiaomi Corp. (authors: Daniel Povey
2
+ # Wei Kang)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ import logging
20
+ import math
21
+ import random
22
+ import sys
23
+ from typing import Optional, Tuple, Union
24
+
25
+ try:
26
+ import k2
27
+ except Exception as e:
28
+ logging.warning(
29
+ f"Failed import k2 with error {e}. Swoosh functions will fallback to PyTorch"
30
+ f" implementation, leading to slower speed and higher memory consumption."
31
+ )
32
+ import torch
33
+ import torch.nn as nn
34
+ from torch import Tensor
35
+
36
+ custom_bwd = lambda func: torch.amp.custom_bwd(func, device_type="cuda")
37
+ custom_fwd = lambda func: torch.amp.custom_fwd(func, device_type="cuda")
38
+
39
+
40
+ def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
41
+ max_value = torch.max(x, y)
42
+ diff = torch.abs(x - y)
43
+ return max_value + torch.log1p(torch.exp(-diff))
44
+
45
+
46
+ # RuntimeError: Exporting the operator logaddexp to ONNX opset version
47
+ # 14 is not supported. Please feel free to request support or submit
48
+ # a pull request on PyTorch GitHub.
49
+ #
50
+ # The following function is to solve the above error when exporting
51
+ # models to ONNX via torch.jit.trace()
52
+ def logaddexp(x: Tensor, y: Tensor) -> Tensor:
53
+ # Caution(fangjun): Put torch.jit.is_scripting() before
54
+ # torch.onnx.is_in_onnx_export();
55
+ # otherwise, it will cause errors for torch.jit.script().
56
+ #
57
+ # torch.logaddexp() works for both torch.jit.script() and
58
+ # torch.jit.trace() but it causes errors for ONNX export.
59
+ #
60
+ if torch.jit.is_scripting():
61
+ # Note: We cannot use torch.jit.is_tracing() here as it also
62
+ # matches torch.onnx.export().
63
+ return torch.logaddexp(x, y)
64
+ elif torch.onnx.is_in_onnx_export():
65
+ return logaddexp_onnx(x, y)
66
+ else:
67
+ # for torch.jit.trace()
68
+ return torch.logaddexp(x, y)
69
+
70
+
71
+ class PiecewiseLinear(object):
72
+ """
73
+ Piecewise linear function, from float to float, specified as nonempty list of (x,y)
74
+ pairs with the x values in order. x values <[initial x] or >[final x] are map to
75
+ [initial y], [final y] respectively.
76
+ """
77
+
78
+ def __init__(self, *args):
79
+ assert len(args) >= 1, len(args)
80
+ if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
81
+ self.pairs = list(args[0].pairs)
82
+ else:
83
+ self.pairs = [(float(x), float(y)) for x, y in args]
84
+ for x, y in self.pairs:
85
+ assert isinstance(x, (float, int)), type(x)
86
+ assert isinstance(y, (float, int)), type(y)
87
+
88
+ for i in range(len(self.pairs) - 1):
89
+ assert self.pairs[i + 1][0] > self.pairs[i][0], (
90
+ i,
91
+ self.pairs[i],
92
+ self.pairs[i + 1],
93
+ )
94
+
95
+ def __str__(self):
96
+ # e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
97
+ return f"PiecewiseLinear({str(self.pairs)[1:-1]})"
98
+
99
+ def __call__(self, x):
100
+ if x <= self.pairs[0][0]:
101
+ return self.pairs[0][1]
102
+ elif x >= self.pairs[-1][0]:
103
+ return self.pairs[-1][1]
104
+ else:
105
+ cur_x, cur_y = self.pairs[0]
106
+ for i in range(1, len(self.pairs)):
107
+ next_x, next_y = self.pairs[i]
108
+ if x >= cur_x and x <= next_x:
109
+ return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
110
+ cur_x, cur_y = next_x, next_y
111
+ assert False
112
+
113
+ def __mul__(self, alpha):
114
+ return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])
115
+
116
+ def __add__(self, x):
117
+ if isinstance(x, (float, int)):
118
+ return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
119
+ s, x = self.get_common_basis(x)
120
+ return PiecewiseLinear(
121
+ *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]
122
+ )
123
+
124
+ def max(self, x):
125
+ if isinstance(x, (float, int)):
126
+ x = PiecewiseLinear((0, x))
127
+ s, x = self.get_common_basis(x, include_crossings=True)
128
+ return PiecewiseLinear(
129
+ *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
130
+ )
131
+
132
+ def min(self, x):
133
+ if isinstance(x, float) or isinstance(x, int):
134
+ x = PiecewiseLinear((0, x))
135
+ s, x = self.get_common_basis(x, include_crossings=True)
136
+ return PiecewiseLinear(
137
+ *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
138
+ )
139
+
140
+ def __eq__(self, other):
141
+ return self.pairs == other.pairs
142
+
143
+ def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False):
144
+ """
145
+ Returns (self_mod, p_mod) which are equivalent piecewise linear
146
+ functions to self and p, but with the same x values.
147
+
148
+ p: the other piecewise linear function
149
+ include_crossings: if true, include in the x values positions
150
+ where the functions indicated by this and p cross.
151
+ """
152
+ assert isinstance(p, PiecewiseLinear), type(p)
153
+
154
+ # get sorted x-values without repetition.
155
+ x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
156
+ y_vals1 = [self(x) for x in x_vals]
157
+ y_vals2 = [p(x) for x in x_vals]
158
+
159
+ if include_crossings:
160
+ extra_x_vals = []
161
+ for i in range(len(x_vals) - 1):
162
+ if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]):
163
+ # if the two lines in this subsegment potentially cross each other..
164
+ diff_cur = abs(y_vals1[i] - y_vals2[i])
165
+ diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
166
+ # `pos`, between 0 and 1, gives the relative x position,
167
+ # with 0 being x_vals[i] and 1 being x_vals[i+1].
168
+ pos = diff_cur / (diff_cur + diff_next)
169
+ extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
170
+ extra_x_vals.append(extra_x_val)
171
+ if len(extra_x_vals) > 0:
172
+ x_vals = sorted(set(x_vals + extra_x_vals))
173
+ y_vals1 = [self(x) for x in x_vals]
174
+ y_vals2 = [p(x) for x in x_vals]
175
+ return (
176
+ PiecewiseLinear(*zip(x_vals, y_vals1)),
177
+ PiecewiseLinear(*zip(x_vals, y_vals2)),
178
+ )
179
+
180
+
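A tiny illustration of the PiecewiseLinear semantics defined above (breakpoints chosen arbitrarily): values are clamped outside the x range and linearly interpolated inside it.

# Illustration only, not part of the commit; assumes PiecewiseLinear from above is in scope.
f = PiecewiseLinear((0.0, 1.0), (100.0, 0.5), (200.0, 0.0))
assert f(-10.0) == 1.0   # before the first x -> first y
assert f(50.0) == 0.75   # linear interpolation between (0, 1.0) and (100, 0.5)
assert f(500.0) == 0.0   # after the last x -> last y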
181
+ class ScheduledFloat(torch.nn.Module):
182
+ """
183
+ This object is a torch.nn.Module only because we want it to show up in
184
+ [top_level module].modules(); it does not have a working forward() function.
185
+ You are supposed to cast it to float, as in, float(parent_module.whatever), and use
186
+ it as something like a dropout prob.
187
+
188
+ It is a floating point value whose value changes depending on the batch count of the
189
+ training loop. It is a piecewise linear function where you specify the (x,y) pairs
190
+ in sorted order on x; x corresponds to the batch index. For batch-index values
191
+ before the first x or after the last x, we just use the first or last y value.
192
+
193
+ Example:
194
+ self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
195
+
196
+ `default` is used when self.batch_count is not set or not in training mode or in
197
+ torch.jit scripting mode.
198
+ """
199
+
200
+ def __init__(self, *args, default: float = 0.0):
201
+ super().__init__()
202
+ # self.batch_count and self.name will be written to in the training loop.
203
+ self.batch_count = None
204
+ self.name = None
205
+ self.default = default
206
+ self.schedule = PiecewiseLinear(*args)
207
+
208
+ def extra_repr(self) -> str:
209
+ return (
210
+ f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"
211
+ )
212
+
213
+ def __float__(self):
214
+ batch_count = self.batch_count
215
+ if (
216
+ batch_count is None
217
+ or not self.training
218
+ or torch.jit.is_scripting()
219
+ or torch.jit.is_tracing()
220
+ ):
221
+ return float(self.default)
222
+ else:
223
+ ans = self.schedule(self.batch_count)
224
+ if random.random() < 0.0002:
225
+ logging.debug(
226
+ f"ScheduledFloat: name={self.name}, "
227
+ f"batch_count={self.batch_count}, ans={ans}"
228
+ )
229
+ return ans
230
+
231
+ def __add__(self, x):
232
+ if isinstance(x, float) or isinstance(x, int):
233
+ return ScheduledFloat(self.schedule + x, default=self.default)
234
+ else:
235
+ return ScheduledFloat(
236
+ self.schedule + x.schedule, default=self.default + x.default
237
+ )
238
+
239
+ def max(self, x):
240
+ if isinstance(x, float) or isinstance(x, int):
241
+ return ScheduledFloat(self.schedule.max(x), default=self.default)
242
+ else:
243
+ return ScheduledFloat(
244
+ self.schedule.max(x.schedule),
245
+ default=max(self.default, x.default),
246
+ )
247
+
248
+
249
+ FloatLike = Union[float, ScheduledFloat]
250
+
251
+
252
+ class CutoffEstimator:
253
+ """
254
+ Estimates cutoffs of an arbitrary numerical quantity such that a specified
255
+ proportion of items will be above the cutoff on average.
256
+
257
+ p is the proportion of items that should be above the cutoff.
258
+ """
259
+
260
+ def __init__(self, p: float):
261
+ self.p = p
262
+ # total count of items
263
+ self.count = 0
264
+ # total count of items that were above the cutoff
265
+ self.count_above = 0
266
+ # initial cutoff value
267
+ self.cutoff = 0
268
+
269
+ def __call__(self, x: float) -> bool:
270
+ """
271
+ Returns true if x is above the cutoff.
272
+ """
273
+ ans = x > self.cutoff
274
+ self.count += 1
275
+ if ans:
276
+ self.count_above += 1
277
+ cur_p = self.count_above / self.count
278
+ delta_p = cur_p - self.p
279
+ if (delta_p > 0) == ans:
280
+ q = abs(delta_p)
281
+ self.cutoff = x * q + self.cutoff * (1 - q)
282
+ return ans
283
+
284
+
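A small sketch of CutoffEstimator: the threshold adapts so that roughly a proportion p of observed values end up above it (numbers are illustrative).

# Illustration only: feed a stream of values and watch the acceptance rate adapt.
import random

est = CutoffEstimator(p=0.25)                    # aim for ~25% of items above the cutoff
above = sum(est(random.random()) for _ in range(10_000))
print(above / 10_000)                            # drifts toward ~0.25 once the cutoff settles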
285
+ class SoftmaxFunction(torch.autograd.Function):
286
+ """
287
+ Tries to handle half-precision derivatives in a randomized way that should
288
+ be more accurate for training than the default behavior.
289
+ """
290
+
291
+ @staticmethod
292
+ def forward(ctx, x: Tensor, dim: int):
293
+ ans = x.softmax(dim=dim)
294
+ # if x dtype is float16, x.softmax() returns a float32 because
295
+ # (presumably) that op does not support float16, and autocast
296
+ # is enabled.
297
+ if torch.is_autocast_enabled():
298
+ ans = ans.to(torch.float16)
299
+ ctx.save_for_backward(ans)
300
+ ctx.x_dtype = x.dtype
301
+ ctx.dim = dim
302
+ return ans
303
+
304
+ @staticmethod
305
+ def backward(ctx, ans_grad: Tensor):
306
+ (ans,) = ctx.saved_tensors
307
+ with torch.amp.autocast("cuda", enabled=False):
308
+ ans_grad = ans_grad.to(torch.float32)
309
+ ans = ans.to(torch.float32)
310
+ x_grad = ans_grad * ans
311
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
312
+ return x_grad, None
313
+
314
+
315
+ def softmax(x: Tensor, dim: int):
316
+ if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
317
+ return x.softmax(dim=dim)
318
+
319
+ return SoftmaxFunction.apply(x, dim)
320
+
321
+
322
+ class BiasNormFunction(torch.autograd.Function):
323
+ # This computes:
324
+ # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
325
+ # return x * scales
326
+ # (after unsqueezing the bias), but it does it in a memory-efficient way so that
327
+ # it can just store the returned value (chances are, this will also be needed for
328
+ # some other reason, related to the next operation, so we can save memory).
329
+ @staticmethod
330
+ def forward(
331
+ ctx,
332
+ x: Tensor,
333
+ bias: Tensor,
334
+ log_scale: Tensor,
335
+ channel_dim: int,
336
+ store_output_for_backprop: bool,
337
+ ) -> Tensor:
338
+ assert bias.ndim == 1
339
+ if channel_dim < 0:
340
+ channel_dim = channel_dim + x.ndim
341
+ ctx.store_output_for_backprop = store_output_for_backprop
342
+ ctx.channel_dim = channel_dim
343
+ for _ in range(channel_dim + 1, x.ndim):
344
+ bias = bias.unsqueeze(-1)
345
+ scales = (
346
+ torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
347
+ ) * log_scale.exp()
348
+ ans = x * scales
349
+ ctx.save_for_backward(
350
+ ans.detach() if store_output_for_backprop else x,
351
+ scales.detach(),
352
+ bias.detach(),
353
+ log_scale.detach(),
354
+ )
355
+ return ans
356
+
357
+ @staticmethod
358
+ def backward(ctx, ans_grad: Tensor) -> Tensor:
359
+ ans_or_x, scales, bias, log_scale = ctx.saved_tensors
360
+ if ctx.store_output_for_backprop:
361
+ x = ans_or_x / scales
362
+ else:
363
+ x = ans_or_x
364
+ x = x.detach()
365
+ x.requires_grad = True
366
+ bias.requires_grad = True
367
+ log_scale.requires_grad = True
368
+ with torch.enable_grad():
369
+ # recompute scales from x, bias and log_scale.
370
+ scales = (
371
+ torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5
372
+ ) * log_scale.exp()
373
+ ans = x * scales
374
+ ans.backward(gradient=ans_grad)
375
+ return x.grad, bias.grad.flatten(), log_scale.grad, None, None
376
+
377
+
378
+ class BiasNorm(torch.nn.Module):
379
+ """
380
+ This is intended to be a simpler, and hopefully cheaper, replacement for
381
+ LayerNorm. The observation this is based on, is that Transformer-type
382
+ networks, especially with pre-norm, sometimes seem to set one of the
383
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
384
+ the LayerNorm because the output magnitude is then not strongly dependent
385
+ on the other (useful) features. Presumably the weight and bias of the
386
+ LayerNorm are required to allow it to do this.
387
+
388
+ Instead, we give the BiasNorm a trainable bias that it can use when
389
+ computing the scale for normalization. We also give it a (scalar)
390
+ trainable scale on the output.
391
+
392
+
393
+ Args:
394
+ num_channels: the number of channels, e.g. 512.
395
+ channel_dim: the axis/dimension corresponding to the channel,
396
+ interpreted as an offset from the input's ndim if negative.
397
+ This is NOT the num_channels; it should typically be one of
398
+ {-2, -1, 0, 1, 2, 3}.
399
+ log_scale: the initial log-scale that we multiply the output by; this
400
+ is learnable.
401
+ log_scale_min: FloatLike, minimum allowed value of log_scale
402
+ log_scale_max: FloatLike, maximum allowed value of log_scale
403
+ store_output_for_backprop: only possibly affects memory use; recommend
404
+ to set to True if you think the output of this module is more likely
405
+ than the input of this module to be required to be stored for the
406
+ backprop.
407
+ """
408
+
409
+ def __init__(
410
+ self,
411
+ num_channels: int,
412
+ channel_dim: int = -1, # CAUTION: see documentation.
413
+ log_scale: float = 1.0,
414
+ log_scale_min: float = -1.5,
415
+ log_scale_max: float = 1.5,
416
+ store_output_for_backprop: bool = False,
417
+ ) -> None:
418
+ super(BiasNorm, self).__init__()
419
+ self.num_channels = num_channels
420
+ self.channel_dim = channel_dim
421
+ self.log_scale = nn.Parameter(torch.tensor(log_scale))
422
+ self.bias = nn.Parameter(torch.zeros(num_channels))
423
+
424
+ self.log_scale_min = log_scale_min
425
+ self.log_scale_max = log_scale_max
426
+
427
+ self.store_output_for_backprop = store_output_for_backprop
428
+
429
+ def forward(self, x: Tensor) -> Tensor:
430
+ assert x.shape[self.channel_dim] == self.num_channels
431
+
432
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
433
+ channel_dim = self.channel_dim
434
+ if channel_dim < 0:
435
+ channel_dim += x.ndim
436
+ bias = self.bias
437
+ for _ in range(channel_dim + 1, x.ndim):
438
+ bias = bias.unsqueeze(-1)
439
+ scales = (
440
+ torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
441
+ ) * self.log_scale.exp()
442
+ return x * scales
443
+
444
+ log_scale = limit_param_value(
445
+ self.log_scale,
446
+ min=float(self.log_scale_min),
447
+ max=float(self.log_scale_max),
448
+ training=self.training,
449
+ )
450
+
451
+ return BiasNormFunction.apply(
452
+ x,
453
+ self.bias,
454
+ log_scale,
455
+ self.channel_dim,
456
+ self.store_output_for_backprop,
457
+ )
458
+
459
+
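A hedged usage sketch of BiasNorm as a drop-in replacement for LayerNorm over the last dimension (shapes are arbitrary).

# Illustration only.
import torch

norm = BiasNorm(num_channels=512, channel_dim=-1)
x = torch.randn(8, 100, 512)
y = norm(x)                      # same shape; scale derived from the RMS of (x - bias)
assert y.shape == x.shape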
460
+ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
461
+ """
462
+ Behaves like a constructor of a modified version of nn.Linear
463
+ that gives an easy way to set the default initial parameter scale.
464
+
465
+ Args:
466
+ Accepts the standard args and kwargs that nn.Linear accepts
467
+ e.g. in_features, out_features, bias=False.
468
+
469
+ initial_scale: you can override this if you want to increase
470
+ or decrease the initial magnitude of the module's output
471
+ (affects the initialization of weight_scale and bias_scale).
472
+ Another option, if you want to do something like this, is
473
+ to re-initialize the parameters.
474
+ """
475
+ ans = nn.Linear(*args, **kwargs)
476
+ with torch.no_grad():
477
+ ans.weight[:] *= initial_scale
478
+ if ans.bias is not None:
479
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
480
+ return ans
481
+
482
+
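ScaledLinear returns an ordinary nn.Linear whose initial weight (and bias) magnitude is shrunk by initial_scale; a one-line sketch:

# Illustration only: a projection whose initial output magnitude is reduced 4x.
proj = ScaledLinear(256, 512, bias=True, initial_scale=0.25)
print(type(proj))                # <class 'torch.nn.modules.linear.Linear'>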
483
+ class BalancerFunction(torch.autograd.Function):
484
+ @staticmethod
485
+ def forward(
486
+ ctx,
487
+ x: Tensor,
488
+ min_mean: float,
489
+ max_mean: float,
490
+ min_rms: float,
491
+ max_rms: float,
492
+ grad_scale: float,
493
+ channel_dim: int,
494
+ ) -> Tensor:
495
+ if channel_dim < 0:
496
+ channel_dim += x.ndim
497
+ ctx.channel_dim = channel_dim
498
+ ctx.save_for_backward(x)
499
+ ctx.config = (
500
+ min_mean,
501
+ max_mean,
502
+ min_rms,
503
+ max_rms,
504
+ grad_scale,
505
+ channel_dim,
506
+ )
507
+ return x
508
+
509
+ @staticmethod
510
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
511
+ (x,) = ctx.saved_tensors
512
+ (
513
+ min_mean,
514
+ max_mean,
515
+ min_rms,
516
+ max_rms,
517
+ grad_scale,
518
+ channel_dim,
519
+ ) = ctx.config
520
+
521
+ try:
522
+ with torch.enable_grad():
523
+ with torch.amp.autocast("cuda", enabled=False):
524
+ x = x.to(torch.float32)
525
+ x = x.detach()
526
+ x.requires_grad = True
527
+ mean_dims = [i for i in range(x.ndim) if i != channel_dim]
528
+ uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True)
529
+ mean = x.mean(dim=mean_dims, keepdim=True)
530
+ stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
531
+ rms = uncentered_var.clamp(min=1.0e-20).sqrt()
532
+
533
+ m = mean / stddev
534
+ # part of loss that relates to mean / stddev
535
+ m_loss = (m - m.clamp(min=min_mean, max=max_mean)).abs()
536
+
537
+ # put a much larger scale on the RMS-max-limit loss, so that if both
538
+ # it and the m_loss are violated we fix the RMS loss first.
539
+ rms_clamped = rms.clamp(min=min_rms, max=max_rms)
540
+ r_loss = (rms_clamped / rms).log().abs()
541
+
542
+ loss = m_loss + r_loss
543
+
544
+ loss.backward(gradient=torch.ones_like(loss))
545
+ loss_grad = x.grad
546
+ loss_grad_rms = (
547
+ (loss_grad**2)
548
+ .mean(dim=mean_dims, keepdim=True)
549
+ .sqrt()
550
+ .clamp(min=1.0e-20)
551
+ )
552
+
553
+ loss_grad = loss_grad * (grad_scale / loss_grad_rms)
554
+
555
+ x_grad_float = x_grad.to(torch.float32)
556
+ # scale each element of loss_grad by the absolute value of the
557
+ # corresponding element of x_grad, which we view as a noisy estimate
558
+ # of its magnitude for that (frame and dimension). later we can
559
+ # consider factored versions.
560
+ x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
561
+ x_grad = x_grad_mod.to(x_grad.dtype)
562
+ except Exception as e:
563
+ logging.info(
564
+ f"Caught exception in Balancer backward: {e}, "
565
+ f"size={list(x_grad.shape)}, will continue."
566
+ )
567
+
568
+ return x_grad, None, None, None, None, None, None
569
+
570
+
571
+ class Balancer(torch.nn.Module):
572
+ """
573
+ Modifies the backpropped derivatives of a function to try to encourage, for
574
+ each channel, that it is positive at least a proportion `min_positive` of the
575
+ time. It does this by multiplying negative derivative values by up to
576
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
577
+ interpolated from 1 at the threshold to those extremal values when none
578
+ of the inputs are positive.
579
+
580
+ Args:
581
+ num_channels: the number of channels
582
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
583
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
584
+ min_positive: the minimum, per channel, of the proportion of the time
585
+ that (x > 0), below which we start to modify the derivatives.
586
+ max_positive: the maximum, per channel, of the proportion of the time
587
+ that (x > 0), above which we start to modify the derivatives.
588
+ grad_scale: determines the 'gain' with which we increase the
589
+ change in gradient once the constraints on min_abs and max_abs
590
+ are violated.
591
+ min_abs: the minimum average-absolute-value difference from the mean
592
+ value per channel, which we allow, before we start to modify
593
+ the derivatives to prevent this.
594
+ max_abs: the maximum average-absolute-value difference from the mean
595
+ value per channel, which we allow, before we start to modify
596
+ the derivatives to prevent this.
597
+ prob: determines the minimum probability with which we modify the
598
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
599
+ on each forward(). This is done randomly to prevent all layers
600
+ from doing it at the same time.
601
+ """
602
+
603
+ def __init__(
604
+ self,
605
+ num_channels: int,
606
+ channel_dim: int,
607
+ min_positive: FloatLike = 0.05,
608
+ max_positive: FloatLike = 0.95,
609
+ min_abs: FloatLike = 0.2,
610
+ max_abs: FloatLike = 100.0,
611
+ grad_scale: FloatLike = 0.04,
612
+ prob: Optional[FloatLike] = None,
613
+ ):
614
+ super().__init__()
615
+
616
+ if prob is None:
617
+ prob = ScheduledFloat((0.0, 0.5), (8000.0, 0.125), default=0.4)
618
+ self.prob = prob
619
+ # 5% of the time we will return and do nothing because memory usage is
620
+ # too high.
621
+ self.mem_cutoff = CutoffEstimator(0.05)
622
+
623
+ # actually self.num_channels is no longer needed except for an assertion.
624
+ self.num_channels = num_channels
625
+ self.channel_dim = channel_dim
626
+ self.min_positive = min_positive
627
+ self.max_positive = max_positive
628
+ self.min_abs = min_abs
629
+ self.max_abs = max_abs
630
+ self.grad_scale = grad_scale
631
+
632
+ def forward(self, x: Tensor) -> Tensor:
633
+ if (
634
+ torch.jit.is_scripting()
635
+ or not x.requires_grad
636
+ or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))
637
+ ):
638
+ return _no_op(x)
639
+
640
+ prob = float(self.prob)
641
+ if random.random() < prob:
642
+ # The following inner-functions convert from the way we historically
643
+ # specified these limitations, as limits on the absolute value and the
644
+ # proportion of positive values, to limits on the RMS value and
645
+ # the (mean / stddev).
646
+ def _abs_to_rms(x):
647
+ # for normally distributed data, if the expected absolute value is x,
648
+ # the expected rms value will be sqrt(pi/2) * x.
649
+ return 1.25331413732 * x
650
+
651
+ def _proportion_positive_to_mean(x):
652
+ def _atanh(x):
653
+ eps = 1.0e-10
654
+ # eps is to prevent crashes if x is exactly 0 or 1.
655
+ # we'll just end up returning a fairly large value.
656
+ return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0
657
+
658
+ def _approx_inverse_erf(x):
659
+ # 1 / (sqrt(pi) * ln(2)),
660
+ # see https://math.stackexchange.com/questions/321569/
661
+ # approximating-the-error-function-erf-by-analytical-functions
662
+ # this approximation is extremely crude and gets progressively worse
663
+ # for x very close to -1 or +1, but we mostly care about the
664
+ # "middle" region
665
+ # e.g. _approx_inverse_erf(0.05) = 0.0407316414078772,
666
+ # and math.erf(0.0407316414078772) = 0.045935330944660666,
667
+ # which is pretty close to 0.05.
668
+ return 0.8139535143 * _atanh(x)
669
+
670
+ # first convert x from the range 0..1 to the range -1..1 which the error
671
+ # function returns
672
+ x = -1 + (2 * x)
673
+ return _approx_inverse_erf(x)
674
+
675
+ min_mean = _proportion_positive_to_mean(float(self.min_positive))
676
+ max_mean = _proportion_positive_to_mean(float(self.max_positive))
677
+ min_rms = _abs_to_rms(float(self.min_abs))
678
+ max_rms = _abs_to_rms(float(self.max_abs))
679
+ grad_scale = float(self.grad_scale)
680
+
681
+ assert x.shape[self.channel_dim] == self.num_channels
682
+
683
+ return BalancerFunction.apply(
684
+ x,
685
+ min_mean,
686
+ max_mean,
687
+ min_rms,
688
+ max_rms,
689
+ grad_scale,
690
+ self.channel_dim,
691
+ )
692
+ else:
693
+ return _no_op(x)
694
+
695
+
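A sketch of how a Balancer is typically wired into a module: it is an identity in the forward pass and only nudges gradients in backprop (hyper-parameters below are arbitrary).

# Illustration only.
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        # keep per-channel activations roughly in [0.2, 10.0] abs, 5%-95% positive
        self.balancer = Balancer(
            num_channels=dim, channel_dim=-1,
            min_positive=0.05, max_positive=0.95,
            min_abs=0.2, max_abs=10.0,
        )

    def forward(self, x):
        return self.balancer(self.linear(x))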
696
+ def penalize_abs_values_gt(
697
+ x: Tensor, limit: float, penalty: float, name: str = None
698
+ ) -> Tensor:
699
+ """
700
+ Returns x unmodified, but in backprop will put a penalty for the excess of
701
+ the absolute values of elements of x over the limit "limit". E.g. if
702
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
703
+
704
+ Caution: the value of this penalty will be affected by grad scaling used
705
+ in automatic mixed precision training. For this reason, where we use this,
706
+ it shouldn't really matter, or may even be helpful; we just use this
707
+ to disallow really implausible values of scores to be given to softmax.
708
+
709
+ The name is for randomly printed debug info.
710
+ """
711
+ x_sign = x.sign()
712
+ over_limit = (x.abs() - limit) > 0
713
+ # The following is a memory efficient way to penalize the absolute values of
714
+ # x that's over the limit. (The memory efficiency comes when you think
715
+ # about which items torch needs to cache for the autograd, and which ones it
716
+ # can throw away). The numerical value of aux_loss as computed here will
717
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
718
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
719
+ # limit).relu().
720
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
721
+ # note: we don't do sum() here on aux_loss, but it's as if we had done
722
+ # sum() due to how with_loss() works.
723
+ x = with_loss(x, aux_loss, name)
724
+ # you must use x for something, or this will be ineffective.
725
+ return x
726
+
727
+
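A sketch of the intended use of penalize_abs_values_gt on attention logits before a softmax (the limit and penalty values are illustrative).

# Illustration only: discourage attention logits with absolute value above 25.
import torch

scores = torch.randn(4, 8, 100, 100, requires_grad=True) * 30.0
scores = penalize_abs_values_gt(scores, limit=25.0, penalty=1.0e-04, name="attn")
attn = scores.softmax(dim=-1)    # the returned tensor must feed the rest of the graph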
728
+ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
729
+ if x.ndim == 2:
730
+ return x.diag()
731
+ else:
732
+ (batch, dim, dim) = x.shape
733
+ x = x.reshape(batch, dim * dim)
734
+ x = x[:, :: dim + 1]
735
+ assert x.shape == (batch, dim)
736
+ return x
737
+
738
+
739
+ def _whitening_metric(x: Tensor, num_groups: int):
740
+ """
741
+ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
742
+ the centered feature covariance are the same within each group's covariance
743
+ matrix and also between groups.
744
+ Args:
745
+ x: a Tensor of shape (*, num_channels)
746
+ num_groups: the number of groups of channels, a number >=1 that divides
747
+ num_channels
748
+ Returns:
749
+ Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
750
+ greater than 1.0 otherwise.
751
+ """
752
+ assert x.dtype != torch.float16
753
+ x = x.reshape(-1, x.shape[-1])
754
+ (num_frames, num_channels) = x.shape
755
+ assert num_channels % num_groups == 0
756
+ channels_per_group = num_channels // num_groups
757
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
758
+ # x now has shape (num_groups, num_frames, channels_per_group)
759
+ # subtract the mean so we use the centered, not uncentered, covariance.
760
+ # My experience has been that when we "mess with the gradients" like this,
761
+ # it's better not do anything that tries to move the mean around, because
762
+ # that can easily cause instability.
763
+ x = x - x.mean(dim=1, keepdim=True)
764
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
765
+ x_covar = torch.matmul(x.transpose(1, 2), x)
766
+ x_covar_mean_diag = _diag(x_covar).mean()
767
+ # the following expression is what we'd get if we took the matrix product
768
+ # of each covariance and measured the mean of its trace, i.e.
769
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
770
+ x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
771
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
772
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
773
+ return metric
774
+
775
+
776
+ class WhiteningPenaltyFunction(torch.autograd.Function):
777
+ @staticmethod
778
+ def forward(ctx, x: Tensor, module: nn.Module) -> Tensor:
779
+ ctx.save_for_backward(x)
780
+ ctx.module = module
781
+ return x
782
+
783
+ @staticmethod
784
+ def backward(ctx, x_grad: Tensor):
785
+ (x_orig,) = ctx.saved_tensors
786
+ w = ctx.module
787
+
788
+ try:
789
+ with torch.enable_grad():
790
+ with torch.amp.autocast("cuda", enabled=False):
791
+ x_detached = x_orig.to(torch.float32).detach()
792
+ x_detached.requires_grad = True
793
+
794
+ metric = _whitening_metric(x_detached, w.num_groups)
795
+
796
+ if random.random() < 0.005 or __name__ == "__main__":
797
+ logging.debug(
798
+ f"Whitening: name={w.name}, num_groups={w.num_groups},"
799
+ f"num_channels={x_orig.shape[-1]}, "
800
+ f"metric={metric.item():.2f}"
801
+ f" vs. limit={float(w.whitening_limit)}"
802
+ )
803
+
804
+ if metric < float(w.whitening_limit):
805
+ w.prob = w.min_prob
806
+ return x_grad, None
807
+ else:
808
+ w.prob = w.max_prob
809
+ metric.backward()
810
+ penalty_grad = x_detached.grad
811
+ scale = w.grad_scale * (
812
+ x_grad.to(torch.float32).norm()
813
+ / (penalty_grad.norm() + 1.0e-20)
814
+ )
815
+ penalty_grad = penalty_grad * scale
816
+ return x_grad + penalty_grad.to(x_grad.dtype), None
817
+ except Exception as e:
818
+ logging.info(
819
+ f"Caught exception in Whiten backward: {e}, "
820
+ f"size={list(x_grad.shape)}, will continue."
821
+ )
822
+ return x_grad, None
823
+
824
+
825
+ class Whiten(nn.Module):
826
+ def __init__(
827
+ self,
828
+ num_groups: int,
829
+ whitening_limit: FloatLike,
830
+ prob: Union[float, Tuple[float, float]],
831
+ grad_scale: FloatLike,
832
+ ):
833
+ """
834
+ Args:
835
+ num_groups: the number of groups to divide the channel dim into before
836
+ whitening. We will attempt to make the feature covariance
837
+ within each group, after mean subtraction, as "white" as possible,
838
+ while having the same trace across all groups.
839
+ whitening_limit: a value greater than 1.0, that dictates how much
840
+ freedom we have to violate the constraints. 1.0 would mean perfectly
841
+ white, with exactly the same trace across groups; larger values
842
+ give more freedom. E.g. 2.0.
843
+ prob: the probability with which we apply the gradient modification
844
+ (also affects the grad scale). May be supplied as a float,
845
+ or as a pair (min_prob, max_prob)
846
+
847
+ grad_scale: determines the scale on the gradient term from this object,
848
+ relative to the rest of the gradient on the attention weights.
849
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
850
+ """
851
+ super(Whiten, self).__init__()
852
+ assert num_groups >= 1
853
+ assert float(whitening_limit) >= 1
854
+ assert grad_scale >= 0
855
+ self.num_groups = num_groups
856
+ self.whitening_limit = whitening_limit
857
+ self.grad_scale = grad_scale
858
+
859
+ if isinstance(prob, float):
860
+ prob = (prob, prob)
861
+ (self.min_prob, self.max_prob) = prob
862
+ assert 0 < self.min_prob <= self.max_prob <= 1
863
+ self.prob = self.max_prob
864
+ self.name = None # will be set in training loop
865
+
866
+ def forward(self, x: Tensor) -> Tensor:
867
+ """
868
+ In the forward pass, this function just returns the input unmodified.
869
+ In the backward pass, it will modify the gradients to ensure that the
870
+ distribution in each group has close to (lambda times I) as the covariance
871
+ after mean subtraction, with the same lambda across groups.
872
+ For whitening_limit > 1, there will be more freedom to violate this
873
+ constraint.
874
+
875
+ Args:
876
+ x: the input of shape (*, num_channels)
877
+
878
+ Returns:
879
+ x, unmodified. You should make sure
880
+ you use the returned value, or the graph will be freed
881
+ and nothing will happen in backprop.
882
+ """
883
+ grad_scale = float(self.grad_scale)
884
+ if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
885
+ return _no_op(x)
886
+ else:
887
+ return WhiteningPenaltyFunction.apply(x, self)
888
+
889
+
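Whiten, like Balancer, returns its input unchanged and only acts on gradients, so its output must be used downstream; a hedged sketch with arbitrary settings.

# Illustration only.
import torch

whiten = Whiten(num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01)
x = torch.randn(8, 100, 512, requires_grad=True)
x = whiten(x)                    # identity in forward; gradient penalty in backward
loss = x.pow(2).mean()           # downstream use keeps the penalty in the graph
loss.backward()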
890
+ class WithLoss(torch.autograd.Function):
891
+ @staticmethod
892
+ def forward(ctx, x: Tensor, y: Tensor, name: str):
893
+ ctx.y_shape = y.shape
894
+ if random.random() < 0.002 and name is not None:
895
+ loss_sum = y.sum().item()
896
+ logging.debug(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}")
897
+ return x
898
+
899
+ @staticmethod
900
+ def backward(ctx, ans_grad: Tensor):
901
+ return (
902
+ ans_grad,
903
+ torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
904
+ None,
905
+ )
906
+
907
+
908
+ def with_loss(x, y, name):
909
+ # returns x but adds y.sum() to the loss function.
910
+ return WithLoss.apply(x, y, name)
911
+
912
+
913
+ class LimitParamValue(torch.autograd.Function):
914
+ @staticmethod
915
+ def forward(ctx, x: Tensor, min: float, max: float):
916
+ ctx.save_for_backward(x)
917
+ assert max >= min
918
+ ctx.min = min
919
+ ctx.max = max
920
+ return x
921
+
922
+ @staticmethod
923
+ def backward(ctx, x_grad: Tensor):
924
+ (x,) = ctx.saved_tensors
925
+ # where x < ctx.min, ensure all grads are negative (this will tend to make
926
+ # x more positive).
927
+ x_grad = x_grad * torch.where(
928
+ torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0
929
+ )
930
+ # where x > ctx.max, ensure all grads are positive (this will tend to make
931
+ # x more negative).
932
+ x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
933
+ return x_grad, None, None
934
+
935
+
936
+ def limit_param_value(
937
+ x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True
938
+ ):
939
+ # You apply this to (typically) an nn.Parameter during training to ensure that its
940
+ # elements (mostly) stay within a supplied range. This is done by modifying the
941
+ # gradients in backprop.
942
+ # It's not necessary to do this on every batch: do it only some of the time,
943
+ # to save a little time.
944
+ if training and random.random() < prob:
945
+ return LimitParamValue.apply(x, min, max)
946
+ else:
947
+ return x
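For illustration, a minimal hedged sketch of how limit_param_value would typically be applied to a learnable scale during training (the parameter name and bounds below are made-up examples, not values from this repository):

    # Hedged usage sketch of limit_param_value; names and bounds are illustrative only.
    import torch
    from zipvoice.models.modules.scaling import limit_param_value

    scale = torch.nn.Parameter(torch.tensor(3.0))  # deliberately above the allowed range
    limited = limit_param_value(scale, min=0.1, max=2.0, prob=1.0)
    loss = (limited * torch.randn(4)).sum()
    loss.backward()
    # Since scale > max, any negative gradient (which would push it even higher under
    # gradient descent) is sign-flipped, so the update pulls it back toward the range.
    print(scale.grad)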
948
+
949
+
950
+ def _no_op(x: Tensor) -> Tensor:
951
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
952
+ return x
953
+ else:
954
+ # a no-op function that will have a node in the autograd graph,
955
+ # to avoid certain bugs relating to backward hooks
956
+ return x.chunk(1, dim=-1)[0]
957
+
958
+
959
+ # Identity more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
960
+ class Identity(torch.nn.Module):
961
+ def __init__(self):
962
+ super(Identity, self).__init__()
963
+
964
+ def forward(self, x):
965
+ return _no_op(x)
966
+
967
+
968
+ # Dropout2 is just like normal dropout, except it supports schedules
969
+ # on the dropout rates.
970
+ class Dropout2(nn.Module):
971
+ def __init__(self, p: FloatLike):
972
+ super().__init__()
973
+ self.p = p
974
+
975
+ def forward(self, x: Tensor) -> Tensor:
976
+ return torch.nn.functional.dropout(x, p=float(self.p), training=self.training)
977
+
978
+
979
+ class MulForDropout3(torch.autograd.Function):
980
+ # returns (x * y * alpha) where alpha is a float and y doesn't require
981
+ # grad and is zero-or-one.
982
+ @staticmethod
983
+ @custom_fwd
984
+ def forward(ctx, x, y, alpha):
985
+ assert not y.requires_grad
986
+ ans = x * y * alpha
987
+ ctx.save_for_backward(ans)
988
+ ctx.alpha = alpha
989
+ return ans
990
+
991
+ @staticmethod
992
+ @custom_bwd
993
+ def backward(ctx, ans_grad):
994
+ (ans,) = ctx.saved_tensors
995
+ x_grad = ctx.alpha * ans_grad * (ans != 0)
996
+ return x_grad, None, None
997
+
998
+
999
+ # Dropout3 is just like normal dropout, except it supports schedules on the dropout
1000
+ # rates, and it lets you choose one dimension to share the dropout mask over
1001
+ class Dropout3(nn.Module):
1002
+ def __init__(self, p: FloatLike, shared_dim: int):
1003
+ super().__init__()
1004
+ self.p = p
1005
+ self.shared_dim = shared_dim
1006
+
1007
+ def forward(self, x: Tensor) -> Tensor:
1008
+ p = float(self.p)
1009
+ if not self.training or p == 0:
1010
+ return _no_op(x)
1011
+ scale = 1.0 / (1 - p)
1012
+ rand_shape = list(x.shape)
1013
+ rand_shape[self.shared_dim] = 1
1014
+ mask = torch.rand(*rand_shape, device=x.device) > p
1015
+ ans = MulForDropout3.apply(x, mask, scale)
1016
+ return ans
1017
+
1018
+
1019
+ class SwooshLFunction(torch.autograd.Function):
1020
+ """
1021
+ swoosh_l(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
1022
+ """
1023
+
1024
+ @staticmethod
1025
+ def forward(ctx, x: Tensor) -> Tensor:
1026
+ requires_grad = x.requires_grad
1027
+ if x.dtype == torch.float16:
1028
+ x = x.to(torch.float32)
1029
+
1030
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1031
+
1032
+ coeff = -0.08
1033
+
1034
+ with torch.amp.autocast("cuda", enabled=False):
1035
+ with torch.enable_grad():
1036
+ x = x.detach()
1037
+ x.requires_grad = True
1038
+ y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035
1039
+
1040
+ if not requires_grad:
1041
+ return y
1042
+
1043
+ y.backward(gradient=torch.ones_like(y))
1044
+
1045
+ grad = x.grad
1046
+ floor = coeff
1047
+ ceil = 1.0 + coeff + 0.005
1048
+
1049
+ d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
1050
+ grad
1051
+ )
1052
+ if __name__ == "__main__":
1053
+ # for self-testing only.
1054
+ assert d_scaled.min() >= 0.0
1055
+ assert d_scaled.max() < 256.0
1056
+
1057
+ d_int = d_scaled.to(torch.uint8)
1058
+ ctx.save_for_backward(d_int)
1059
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
1060
+ y = y.to(torch.float16)
1061
+ return y
1062
+
1063
+ @staticmethod
1064
+ def backward(ctx, y_grad: Tensor) -> Tensor:
1065
+ (d,) = ctx.saved_tensors
1066
+ # the same constants as used in forward pass.
1067
+ coeff = -0.08
1068
+ floor = coeff
1069
+ ceil = 1.0 + coeff + 0.005
1070
+ d = d * ((ceil - floor) / 255.0) + floor
1071
+ return y_grad * d
1072
+
1073
+
1074
+ class SwooshL(torch.nn.Module):
1075
+ def forward(self, x: Tensor) -> Tensor:
1076
+ """Return Swoosh-L activation."""
1077
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1078
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1079
+ return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
1080
+ return SwooshLFunction.apply(x)
1081
+
1082
+
1083
+ class SwooshLOnnx(torch.nn.Module):
1084
+ def forward(self, x: Tensor) -> Tensor:
1085
+ """Return Swoosh-L activation."""
1086
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1087
+ return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035
1088
+
1089
+
1090
+ class SwooshRFunction(torch.autograd.Function):
1091
+ """
1092
+ swoosh_r(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
1093
+
1094
+ derivatives are between -0.08 and 0.92.
1095
+ """
1096
+
1097
+ @staticmethod
1098
+ def forward(ctx, x: Tensor) -> Tensor:
1099
+ requires_grad = x.requires_grad
1100
+
1101
+ if x.dtype == torch.float16:
1102
+ x = x.to(torch.float32)
1103
+
1104
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1105
+
1106
+ with torch.amp.autocast("cuda", enabled=False):
1107
+ with torch.enable_grad():
1108
+ x = x.detach()
1109
+ x.requires_grad = True
1110
+ y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
1111
+
1112
+ if not requires_grad:
1113
+ return y
1114
+ y.backward(gradient=torch.ones_like(y))
1115
+
1116
+ grad = x.grad
1117
+ floor = -0.08
1118
+ ceil = 0.925
1119
+
1120
+ d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
1121
+ grad
1122
+ )
1123
+ if __name__ == "__main__":
1124
+ # for self-testing only.
1125
+ assert d_scaled.min() >= 0.0
1126
+ assert d_scaled.max() < 256.0
1127
+
1128
+ d_int = d_scaled.to(torch.uint8)
1129
+ ctx.save_for_backward(d_int)
1130
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
1131
+ y = y.to(torch.float16)
1132
+ return y
1133
+
1134
+ @staticmethod
1135
+ def backward(ctx, y_grad: Tensor) -> Tensor:
1136
+ (d,) = ctx.saved_tensors
1137
+ # the same constants as used in forward pass.
1138
+ floor = -0.08
1139
+ ceil = 0.925
1140
+ d = d * ((ceil - floor) / 255.0) + floor
1141
+ return y_grad * d
1142
+
1143
+
1144
+ class SwooshR(torch.nn.Module):
1145
+ def forward(self, x: Tensor) -> Tensor:
1146
+ """Return Swoosh-R activation."""
1147
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1148
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1149
+ return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
1150
+ return SwooshRFunction.apply(x)
1151
+
1152
+
1153
+ class SwooshROnnx(torch.nn.Module):
1154
+ def forward(self, x: Tensor) -> Tensor:
1155
+ """Return Swoosh-R activation."""
1156
+ zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
1157
+ return logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687
1158
+
1159
+
1160
+ # simple version of SwooshL that does not redefine the backprop, used in
1161
+ # ActivationDropoutAndLinearFunction.
1162
+ def SwooshLForward(x: Tensor):
1163
+ with torch.amp.autocast("cuda", enabled=False):
1164
+ x = x.to(torch.float32)
1165
+ x_offset = x - 4.0
1166
+ log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
1167
+ log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
1168
+ return log_sum - 0.08 * x - 0.035
1169
+
1170
+
1171
+ # simple version of SwooshR that does not redefine the backprop, used in
1172
+ # ActivationDropoutAndLinearFunction.
1173
+ def SwooshRForward(x: Tensor):
1174
+ with torch.amp.autocast("cuda", enabled=False):
1175
+ x = x.to(torch.float32)
1176
+ x_offset = x - 1.0
1177
+ log_sum = (1.0 + x_offset.exp()).log().to(x.dtype)
1178
+ log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum)
1179
+ return log_sum - 0.08 * x - 0.313261687
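As a quick sanity check on the two activations defined above, a small hedged snippet (assuming the scaling module is importable under the path used by the other files in this commit); the printed values follow directly from the formulas:

    # Numeric spot-check of SwooshR / SwooshL; printed values are rounded.
    import torch
    from zipvoice.models.modules.scaling import SwooshL, SwooshR

    x = torch.tensor([0.0])
    print(SwooshR()(x))  # ~0.000, because log(1 + exp(-1)) = 0.313261687 exactly cancels the offset
    print(SwooshL()(x))  # ~-0.017, i.e. log(1 + exp(-4)) - 0.035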
1180
+
1181
+
1182
+ class ActivationDropoutAndLinearFunction(torch.autograd.Function):
1183
+ @staticmethod
1184
+ @custom_fwd
1185
+ def forward(
1186
+ ctx,
1187
+ x: Tensor,
1188
+ weight: Tensor,
1189
+ bias: Optional[Tensor],
1190
+ activation: str,
1191
+ dropout_p: float,
1192
+ dropout_shared_dim: Optional[int],
1193
+ ):
1194
+ if dropout_p != 0.0:
1195
+ dropout_shape = list(x.shape)
1196
+ if dropout_shared_dim is not None:
1197
+ dropout_shape[dropout_shared_dim] = 1
1198
+ # else it won't be very memory efficient.
1199
+ dropout_mask = (1.0 / (1.0 - dropout_p)) * (
1200
+ torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p
1201
+ )
1202
+ else:
1203
+ dropout_mask = None
1204
+
1205
+ ctx.save_for_backward(x, weight, bias, dropout_mask)
1206
+
1207
+ ctx.activation = activation
1208
+
1209
+ forward_activation_dict = {
1210
+ "SwooshL": k2.swoosh_l_forward,
1211
+ "SwooshR": k2.swoosh_r_forward,
1212
+ }
1213
+ # it will raise a KeyError if this fails. This will be an error. We let it
1214
+ # propagate to the user.
1215
+ activation_func = forward_activation_dict[activation]
1216
+ x = activation_func(x)
1217
+ if dropout_mask is not None:
1218
+ x = x * dropout_mask
1219
+ x = torch.nn.functional.linear(x, weight, bias)
1220
+ return x
1221
+
1222
+ @staticmethod
1223
+ @custom_bwd
1224
+ def backward(ctx, ans_grad: Tensor):
1225
+ saved = ctx.saved_tensors
1226
+ (x, weight, bias, dropout_mask) = saved
1227
+
1228
+ forward_and_deriv_activation_dict = {
1229
+ "SwooshL": k2.swoosh_l_forward_and_deriv,
1230
+ "SwooshR": k2.swoosh_r_forward_and_deriv,
1231
+ }
1232
+ # the following lines will raise a KeyError if the activation is unrecognized.
1233
+ # This will be an error. We let it propagate to the user.
1234
+ func = forward_and_deriv_activation_dict[ctx.activation]
1235
+
1236
+ y, func_deriv = func(x)
1237
+ if dropout_mask is not None:
1238
+ y = y * dropout_mask
1239
+ # now compute derivative of y w.r.t. weight and bias..
1240
+ # y: (..., in_channels), ans_grad: (..., out_channels),
1241
+ (out_channels, in_channels) = weight.shape
1242
+
1243
+ in_channels = y.shape[-1]
1244
+ g = ans_grad.reshape(-1, out_channels)
1245
+ weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels))
1246
+ y_deriv = torch.matmul(ans_grad, weight)
1247
+ bias_deriv = None if bias is None else g.sum(dim=0)
1248
+ x_deriv = y_deriv * func_deriv
1249
+ if dropout_mask is not None:
1250
+ # order versus func_deriv does not matter
1251
+ x_deriv = x_deriv * dropout_mask
1252
+
1253
+ return x_deriv, weight_deriv, bias_deriv, None, None, None
1254
+
1255
+
1256
+ class ActivationDropoutAndLinear(torch.nn.Module):
1257
+ """
1258
+ This merges an activation function followed by dropout and then a nn.Linear module;
1259
+ it does so in a memory efficient way so that it only stores the input to the whole
1260
+ module. If activation == SwooshL and dropout_shared_dim != None, this will be
1261
+ equivalent to:
1262
+ nn.Sequential(SwooshL(),
1263
+ Dropout3(dropout_p, shared_dim=dropout_shared_dim),
1264
+ ScaledLinear(in_channels, out_channels, bias=bias,
1265
+ initial_scale=initial_scale))
1266
+ If dropout_shared_dim is None, the dropout would be equivalent to
1267
+ Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout
1268
+ mask is smaller.
1269
+
1270
+ Args:
1271
+ in_channels: number of input channels, e.g. 256
1272
+ out_channels: number of output channels, e.g. 256
1273
+ bias: if true, have a bias
1274
+ activation: the activation function, for now just support SwooshL.
1275
+ dropout_p: the dropout probability or schedule (happens after nonlinearity).
1276
+ dropout_shared_dim: the dimension, if any, across which the dropout mask is
1277
+ shared (e.g. the time dimension). If None, this may be less memory
1278
+ efficient if there are modules before this one that cache the input
1279
+ for their backprop (e.g. Balancer or Whiten).
1280
+ """
1281
+
1282
+ def __init__(
1283
+ self,
1284
+ in_channels: int,
1285
+ out_channels: int,
1286
+ bias: bool = True,
1287
+ activation: str = "SwooshL",
1288
+ dropout_p: FloatLike = 0.0,
1289
+ dropout_shared_dim: Optional[int] = -1,
1290
+ initial_scale: float = 1.0,
1291
+ ):
1292
+ super().__init__()
1293
+ # create a temporary module of nn.Linear that we'll steal the
1294
+ # weights and bias from
1295
+ l = ScaledLinear(
1296
+ in_channels, out_channels, bias=bias, initial_scale=initial_scale
1297
+ )
1298
+
1299
+ self.weight = l.weight
1300
+ # register_parameter properly handles making it a parameter when l.bias
1301
+ # is None. I think there is some reason for doing it this way rather
1302
+ # than just setting it to None but I don't know what it is, maybe
1303
+ # something to do with exporting the module..
1304
+ self.register_parameter("bias", l.bias)
1305
+
1306
+ self.activation = activation
1307
+ self.dropout_p = dropout_p
1308
+ self.dropout_shared_dim = dropout_shared_dim
1309
+
1310
+ def forward(self, x: Tensor):
1311
+ if (
1312
+ torch.jit.is_scripting()
1313
+ or torch.jit.is_tracing()
1314
+ or "k2" not in sys.modules
1315
+ ):
1316
+ if self.activation == "SwooshL":
1317
+ x = SwooshLForward(x)
1318
+ elif self.activation == "SwooshR":
1319
+ x = SwooshRForward(x)
1320
+ else:
1321
+ assert False, self.activation
1322
+ return torch.nn.functional.linear(x, self.weight, self.bias)
1323
+
1324
+ return ActivationDropoutAndLinearFunction.apply(
1325
+ x,
1326
+ self.weight,
1327
+ self.bias,
1328
+ self.activation,
1329
+ float(self.dropout_p),
1330
+ self.dropout_shared_dim,
1331
+ )
1332
+
1333
+
1334
+ def _test_whiten():
1335
+ for proportion in [0.1, 0.5, 10.0]:
1336
+ logging.info(f"_test_whiten(): proportion = {proportion}")
1337
+ x = torch.randn(100, 128)
1338
+ direction = torch.randn(128)
1339
+ coeffs = torch.randn(100, 1)
1340
+ x += proportion * direction * coeffs
1341
+
1342
+ x.requires_grad = True
1343
+
1344
+ m = Whiten(
1345
+ 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
1346
+ ) # grad_scale
1347
+
1348
+ for _ in range(4):
1349
+ y = m(x)
1350
+
1351
+ y_grad = torch.randn_like(x)
1352
+ y.backward(gradient=y_grad)
1353
+
1354
+ if proportion < 0.2:
1355
+ assert torch.allclose(x.grad, y_grad)
1356
+ elif proportion > 1.0:
1357
+ assert not torch.allclose(x.grad, y_grad)
1358
+
1359
+
1360
+ def _test_balancer_sign():
1361
+ probs = torch.arange(0, 1, 0.01)
1362
+ N = 1000
1363
+ x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
1364
+ x = x.detach()
1365
+ x.requires_grad = True
1366
+ m = Balancer(
1367
+ probs.numel(),
1368
+ channel_dim=0,
1369
+ min_positive=0.05,
1370
+ max_positive=0.95,
1371
+ min_abs=0.0,
1372
+ prob=1.0,
1373
+ )
1374
+
1375
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
1376
+
1377
+ y = m(x)
1378
+ y.backward(gradient=y_grad)
1379
+ print("_test_balancer_sign: x = ", x)
1380
+ print("_test_balancer_sign: y grad = ", y_grad)
1381
+ print("_test_balancer_sign: x grad = ", x.grad)
1382
+
1383
+
1384
+ def _test_balancer_magnitude():
1385
+ magnitudes = torch.arange(0, 1, 0.01)
1386
+ N = 1000
1387
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
1388
+ x = x.detach()
1389
+ x.requires_grad = True
1390
+ m = Balancer(
1391
+ magnitudes.numel(),
1392
+ channel_dim=0,
1393
+ min_positive=0.0,
1394
+ max_positive=1.0,
1395
+ min_abs=0.2,
1396
+ max_abs=0.7,
1397
+ prob=1.0,
1398
+ )
1399
+
1400
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
1401
+
1402
+ y = m(x)
1403
+ y.backward(gradient=y_grad)
1404
+ print("_test_balancer_magnitude: x = ", x)
1405
+ print("_test_balancer_magnitude: y grad = ", y_grad)
1406
+ print("_test_balancer_magnitude: x grad = ", x.grad)
1407
+
1408
+
1409
+ def _test_swooshl_deriv():
1410
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
1411
+ x.requires_grad = True
1412
+ m = SwooshL()
1413
+
1414
+ tol = 1.0 / 255.0
1415
+ torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
1416
+
1417
+ # for self-test.
1418
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
1419
+ x.requires_grad = True
1420
+ y = m(x)
1421
+ return y
1422
+
1423
+
1424
+ def _test_swooshr_deriv():
1425
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
1426
+ x.requires_grad = True
1427
+ m = SwooshR()
1428
+
1429
+ tol = 1.0 / 255.0
1430
+ torch.autograd.gradcheck(m, x, atol=tol, eps=0.01)
1431
+
1432
+ # for self-test.
1433
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
1434
+ x.requires_grad = True
1435
+ y = m(x)
1436
+ return y
1437
+
1438
+
1439
+ def _test_softmax():
1440
+ a = torch.randn(2, 10, dtype=torch.float64)
1441
+ b = a.clone()
1442
+ a.requires_grad = True
1443
+ b.requires_grad = True
1444
+ a.softmax(dim=1)[:, 0].sum().backward()
1445
+ print("a grad = ", a.grad)
1446
+ softmax(b, dim=1)[:, 0].sum().backward()
1447
+ print("b grad = ", b.grad)
1448
+ assert torch.allclose(a.grad, b.grad)
1449
+
1450
+
1451
+ def _test_piecewise_linear():
1452
+ p = PiecewiseLinear((0, 10.0))
1453
+ for x in [-100, 0, 100]:
1454
+ assert p(x) == 10.0
1455
+ p = PiecewiseLinear((0, 10.0), (1, 0.0))
1456
+ for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
1457
+ print("x, y = ", x, y)
1458
+ assert p(x) == y, (x, p(x), y)
1459
+
1460
+ q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
1461
+ x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
1462
+ pq = p.max(q)
1463
+ for x in x_vals:
1464
+ y1 = max(p(x), q(x))
1465
+ y2 = pq(x)
1466
+ assert abs(y1 - y2) < 0.001
1467
+ pq = p.min(q)
1468
+ for x in x_vals:
1469
+ y1 = min(p(x), q(x))
1470
+ y2 = pq(x)
1471
+ assert abs(y1 - y2) < 0.001
1472
+ pq = p + q
1473
+ for x in x_vals:
1474
+ y1 = p(x) + q(x)
1475
+ y2 = pq(x)
1476
+ assert abs(y1 - y2) < 0.001
1477
+
1478
+
1479
+ def _test_activation_dropout_and_linear():
1480
+ in_channels = 20
1481
+ out_channels = 30
1482
+
1483
+ for bias in [True, False]:
1484
+ # actually we don't test for dropout_p != 0.0 because forward functions will
1485
+ # give different answers. This is because we are using the k2 implementation of
1486
+ # swoosh_l and swoosh_r inside SwooshL() and SwooshR(), and they call randn()
1487
+ # internally, messing up the random state.
1488
+ for dropout_p in [0.0]:
1489
+ for activation in ["SwooshL", "SwooshR"]:
1490
+ m1 = nn.Sequential(
1491
+ SwooshL() if activation == "SwooshL" else SwooshR(),
1492
+ Dropout3(p=dropout_p, shared_dim=-1),
1493
+ ScaledLinear(
1494
+ in_channels, out_channels, bias=bias, initial_scale=0.5
1495
+ ),
1496
+ )
1497
+ m2 = ActivationDropoutAndLinear(
1498
+ in_channels,
1499
+ out_channels,
1500
+ bias=bias,
1501
+ initial_scale=0.5,
1502
+ activation=activation,
1503
+ dropout_p=dropout_p,
1504
+ )
1505
+ with torch.no_grad():
1506
+ m2.weight[:] = m1[2].weight
1507
+ if bias:
1508
+ m2.bias[:] = m1[2].bias
1509
+ # make sure forward gives same result.
1510
+ x1 = torch.randn(10, in_channels)
1511
+ x1.requires_grad = True
1512
+
1513
+ # TEMP.
1514
+ assert torch.allclose(
1515
+ SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03
1516
+ )
1517
+
1518
+ x2 = x1.clone().detach()
1519
+ x2.requires_grad = True
1520
+ seed = 10
1521
+ torch.manual_seed(seed)
1522
+ y1 = m1(x1)
1523
+ y_grad = torch.randn_like(y1)
1524
+ y1.backward(gradient=y_grad)
1525
+ torch.manual_seed(seed)
1526
+ y2 = m2(x2)
1527
+ y2.backward(gradient=y_grad)
1528
+
1529
+ print(
1530
+ f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}"
1531
+ )
1532
+ print("y1 = ", y1)
1533
+ print("y2 = ", y2)
1534
+ assert torch.allclose(y1, y2, atol=0.02)
1535
+ assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05)
1536
+ if bias:
1537
+ assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05)
1538
+ print("x1.grad = ", x1.grad)
1539
+ print("x2.grad = ", x2.grad)
1540
+
1541
+ def isclose(a, b):
1542
+ # return true if cosine similarity is > 0.9.
1543
+ return (a * b).sum() > 0.9 * (
1544
+ (a**2).sum() * (b**2).sum()
1545
+ ).sqrt()
1546
+
1547
+ # the SwooshL() implementation has a noisy gradient due to 1-byte
1548
+ # storage of it.
1549
+ assert isclose(x1.grad, x2.grad)
1550
+
1551
+
1552
+ if __name__ == "__main__":
1553
+ logging.getLogger().setLevel(logging.DEBUG)
1554
+ torch.set_num_threads(1)
1555
+ torch.set_num_interop_threads(1)
1556
+ _test_piecewise_linear()
1557
+ _test_softmax()
1558
+ _test_whiten()
1559
+ _test_balancer_sign()
1560
+ _test_balancer_magnitude()
1561
+ _test_swooshr_deriv()
1562
+ _test_swooshl_deriv()
1563
+ _test_activation_dropout_and_linear()
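For orientation, a minimal hedged sketch of how these gradient-shaping modules are dropped into a network: both return their input unchanged in the forward pass and only adjust gradients in backprop. The module path and constructor values mirror the tests above and the Zipformer layer later in this commit, but the tiny model itself is made up:

    # Hedged usage sketch; TinyBlock and its dimensions are illustrative only.
    import torch
    import torch.nn as nn
    from zipvoice.models.modules.scaling import Balancer, Whiten

    class TinyBlock(nn.Module):
        def __init__(self, dim: int = 256):
            super().__init__()
            self.linear = nn.Linear(dim, dim)
            # Identity in forward; these modules only reshape gradients during training.
            self.balancer = Balancer(dim, channel_dim=-1, min_positive=0.45,
                                     max_positive=0.55, min_abs=0.2, max_abs=4.0)
            self.whiten = Whiten(num_groups=1, whitening_limit=5.0,
                                 prob=(0.025, 0.25), grad_scale=0.01)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.whiten(self.balancer(self.linear(x)))

    x = torch.randn(10, 4, 256, requires_grad=True)
    y = TinyBlock()(x)
    y.sum().backward()  # whitening/balancing penalties are folded into x.grad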
zipvoice/models/modules/solver.py ADDED
@@ -0,0 +1,281 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from typing import Optional, Union
19
+
20
+ import torch
21
+
22
+
23
+ class DiffusionModel(torch.nn.Module):
24
+ """A wrapper of diffusion models for inference.
25
+ Args:
26
+ model: The diffusion model.
27
+ func_name: The function name to call.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ model: torch.nn.Module,
33
+ func_name: str = "forward_fm_decoder",
34
+ ):
35
+ super().__init__()
36
+ self.model = model
37
+ self.func_name = func_name
38
+ self.model_func = getattr(self.model, func_name)
39
+
40
+ def forward(
41
+ self,
42
+ t: torch.Tensor,
43
+ x: torch.Tensor,
44
+ text_condition: torch.Tensor,
45
+ speech_condition: torch.Tensor,
46
+ padding_mask: Optional[torch.Tensor] = None,
47
+ guidance_scale: Union[float, torch.Tensor] = 0.0,
48
+ **kwargs
49
+ ) -> torch.Tensor:
50
+ """
51
+ Forward function that handles the classifier-free guidance.
52
+ Args:
53
+ t: The current timestep, a tensor of a single float.
54
+ x: The initial value, with the shape (batch, seq_len, emb_dim).
55
+ text_condition: The text_condition of the diffusion model, with
56
+ the shape (batch, seq_len, emb_dim).
57
+ speech_condition: The speech_condition of the diffusion model, with the
58
+ shape (batch, seq_len, emb_dim).
59
+ padding_mask: The mask for padding; True means masked position, with the
60
+ shape (batch, seq_len).
61
+ guidance_scale: The scale of classifier-free guidance, a float or a tensor
62
+ of shape (batch, 1, 1).
63
+ Returns:
64
+ The prediction with the shape (batch, seq_len, emb_dim).
65
+ """
66
+ if not torch.is_tensor(guidance_scale):
67
+ guidance_scale = torch.tensor(
68
+ guidance_scale, dtype=t.dtype, device=t.device
69
+ )
70
+
71
+ if (guidance_scale == 0.0).all():
72
+ return self.model_func(
73
+ t=t,
74
+ xt=x,
75
+ text_condition=text_condition,
76
+ speech_condition=speech_condition,
77
+ padding_mask=padding_mask,
78
+ **kwargs
79
+ )
80
+ else:
81
+ assert t.dim() == 0
82
+
83
+ x = torch.cat([x] * 2, dim=0)
84
+ padding_mask = torch.cat([padding_mask] * 2, dim=0)
85
+
86
+ text_condition = torch.cat(
87
+ [torch.zeros_like(text_condition), text_condition], dim=0
88
+ )
89
+
90
+ if t > 0.5:
91
+ speech_condition = torch.cat(
92
+ [torch.zeros_like(speech_condition), speech_condition], dim=0
93
+ )
94
+ else:
95
+ guidance_scale = guidance_scale * 2
96
+ speech_condition = torch.cat(
97
+ [speech_condition, speech_condition], dim=0
98
+ )
99
+
100
+ data_uncond, data_cond = self.model_func(
101
+ t=t,
102
+ xt=x,
103
+ text_condition=text_condition,
104
+ speech_condition=speech_condition,
105
+ padding_mask=padding_mask,
106
+ **kwargs
107
+ ).chunk(2, dim=0)
108
+
109
+ res = (1 + guidance_scale) * data_cond - guidance_scale * data_uncond
110
+ return res
111
+
112
+
113
+ class DistillDiffusionModel(DiffusionModel):
114
+ """A wrapper of distilled diffusion models for inference.
115
+ Args:
116
+ model: The distilled diffusion model.
117
+ func_name: The function name to call.
118
+ """
119
+
120
+ def __init__(
121
+ self,
122
+ model: torch.nn.Module,
123
+ func_name: str = "forward_fm_decoder",
124
+ ):
125
+ super().__init__(model=model, func_name=func_name)
126
+
127
+ def forward(
128
+ self,
129
+ t: torch.Tensor,
130
+ x: torch.Tensor,
131
+ text_condition: torch.Tensor,
132
+ speech_condition: torch.Tensor,
133
+ padding_mask: Optional[torch.Tensor] = None,
134
+ guidance_scale: Union[float, torch.Tensor] = 0.0,
135
+ **kwargs
136
+ ) -> torch.Tensor:
137
+ """
138
+ Forward function that handles the classifier-free guidance.
139
+ Args:
140
+ t: The current timestep, a tensor of a single float.
141
+ x: The initial value, with the shape (batch, seq_len, emb_dim).
142
+ text_condition: The text_condition of the diffusion model, with
143
+ the shape (batch, seq_len, emb_dim).
144
+ speech_condition: The speech_condition of the diffusion model, with the
145
+ shape (batch, seq_len, emb_dim).
146
+ padding_mask: The mask for padding; True means masked position, with the
147
+ shape (batch, seq_len).
148
+ guidance_scale: The scale of classifier-free guidance, a float or a tensor
149
+ of shape (batch, 1, 1).
150
+ Returns:
151
+ The prediction with the shape (batch, seq_len, emb_dim).
152
+ """
153
+ if not torch.is_tensor(guidance_scale):
154
+ guidance_scale = torch.tensor(
155
+ guidance_scale, dtype=t.dtype, device=t.device
156
+ )
157
+ return self.model_func(
158
+ t=t,
159
+ xt=x,
160
+ text_condition=text_condition,
161
+ speech_condition=speech_condition,
162
+ padding_mask=padding_mask,
163
+ guidance_scale=guidance_scale,
164
+ **kwargs
165
+ )
166
+
167
+
168
+ class EulerSolver:
169
+ def __init__(
170
+ self,
171
+ model: torch.nn.Module,
172
+ func_name: str = "forward_fm_decoder",
173
+ ):
174
+ """Construct a Euler Solver
175
+ Args:
176
+ model: The diffusion model.
177
+ func_name: The function name to call.
178
+ """
179
+
180
+ self.model = DiffusionModel(model, func_name=func_name)
181
+
182
+ def sample(
183
+ self,
184
+ x: torch.Tensor,
185
+ text_condition: torch.Tensor,
186
+ speech_condition: torch.Tensor,
187
+ padding_mask: torch.Tensor,
188
+ num_step: int = 10,
189
+ guidance_scale: Union[float, torch.Tensor] = 0.0,
190
+ t_start: float = 0.0,
191
+ t_end: float = 1.0,
192
+ t_shift: float = 1.0,
193
+ **kwargs
194
+ ) -> torch.Tensor:
195
+ """
196
+ Compute the sample at time `t_end` by Euler Solver.
197
+ Args:
198
+ x: The initial value at time `t_start`, with the shape (batch, seq_len,
199
+ emb_dim).
200
+ text_condition: The text condition of the diffusion model, with the
201
+ shape (batch, seq_len, emb_dim).
202
+ speech_condition: The speech condition of the diffusion model, with the
203
+ shape (batch, seq_len, emb_dim).
204
+ padding_mask: The mask for padding; True means masked position, with the
205
+ shape (batch, seq_len).
206
+ num_step: The number of ODE steps.
207
+ guidance_scale: The scale for classifier-free guidance, which is
208
+ a float or a tensor with the shape (batch, 1, 1).
209
+ t_start: the start timestep in the range of [0, 1].
210
+ t_end: the end time_step in the range of [0, 1].
211
+ t_shift: shift the t toward smaller numbers so that the sampling
212
+ will emphasize low SNR region. Should be in the range of (0, 1].
213
+ The shifting will be more significant when the number is smaller.
214
+
215
+ Returns:
216
+ The approximated solution at time `t_end`.
217
+ """
218
+ device = x.device
219
+ assert isinstance(t_start, float) and isinstance(t_end, float)
220
+
221
+ timesteps = get_time_steps(
222
+ t_start=t_start,
223
+ t_end=t_end,
224
+ num_step=num_step,
225
+ t_shift=t_shift,
226
+ device=device,
227
+ )
228
+
229
+ for step in range(num_step):
230
+ v = self.model(
231
+ t=timesteps[step],
232
+ x=x,
233
+ text_condition=text_condition,
234
+ speech_condition=speech_condition,
235
+ padding_mask=padding_mask,
236
+ guidance_scale=guidance_scale,
237
+ **kwargs
238
+ )
239
+ x = x + v * (timesteps[step + 1] - timesteps[step])
240
+ return x
241
+
242
+
243
+ class DistillEulerSolver(EulerSolver):
244
+ def __init__(
245
+ self,
246
+ model: torch.nn.Module,
247
+ func_name: str = "forward_fm_decoder",
248
+ ):
249
+ """Construct a Euler Solver for distilled diffusion models.
250
+ Args:
251
+ model: The diffusion model.
252
+ """
253
+ self.model = DistillDiffusionModel(model, func_name=func_name)
254
+
255
+
256
+ def get_time_steps(
257
+ t_start: float = 0.0,
258
+ t_end: float = 1.0,
259
+ num_step: int = 10,
260
+ t_shift: float = 1.0,
261
+ device: torch.device = torch.device("cpu"),
262
+ ) -> torch.Tensor:
263
+ """Compute the intermediate time steps for sampling.
264
+
265
+ Args:
266
+ t_start: The starting time of the sampling (default is 0).
267
+ t_end: The ending time of the sampling (default is 1).
268
+ num_step: The number of sampling steps.
269
+ t_shift: shift the t toward smaller numbers so that the sampling
270
+ will emphasize low SNR region. Should be in the range of (0, 1].
271
+ The shifting will be more significant when the number is smaller.
272
+ device: A torch device.
273
+ Returns:
274
+ The time step with the shape (num_step + 1,).
275
+ """
276
+
277
+ timesteps = torch.linspace(t_start, t_end, num_step + 1).to(device)
278
+
279
+ timesteps = t_shift * timesteps / (1 + (t_shift - 1) * timesteps)
280
+
281
+ return timesteps
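To make the effect of t_shift concrete, a short hedged example (assuming the module is importable as zipvoice.models.modules.solver): smaller t_shift values compress the grid toward t = 0, and EulerSolver.sample then integrates the predicted velocity over exactly this grid.

    # Hedged illustration of the sampling grid; printed values are rounded.
    import torch
    from zipvoice.models.modules.solver import get_time_steps

    print(get_time_steps(num_step=4))                # tensor([0.00, 0.25, 0.50, 0.75, 1.00])
    print(get_time_steps(num_step=4, t_shift=0.5))   # roughly [0.00, 0.14, 0.33, 0.60, 1.00]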
zipvoice/models/modules/zipformer.py ADDED
@@ -0,0 +1,1680 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey,
3
+ # Zengwei Yao,
4
+ # Wei Kang
5
+ # Han Zhu)
6
+ #
7
+ # See ../../../../LICENSE for clarification regarding multiple authors
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import copy
22
+ import logging
23
+ import math
24
+ import random
25
+ from typing import Optional, Tuple, Union
26
+
27
+ import torch
28
+ from torch import Tensor, nn
29
+
30
+ from zipvoice.models.modules.scaling import (
31
+ ActivationDropoutAndLinear,
32
+ Balancer,
33
+ BiasNorm,
34
+ Dropout2,
35
+ FloatLike,
36
+ Identity,
37
+ ScaledLinear,
38
+ ScheduledFloat,
39
+ SwooshR,
40
+ Whiten,
41
+ limit_param_value,
42
+ penalize_abs_values_gt,
43
+ softmax,
44
+ )
45
+
46
+
47
+ def timestep_embedding(timesteps, dim, max_period=10000):
48
+ """Create sinusoidal timestep embeddings.
49
+
50
+ :param timesteps: shape of (N) or (N, T)
51
+ :param dim: the dimension of the output.
52
+ :param max_period: controls the minimum frequency of the embeddings.
53
+ :return: an Tensor of positional embeddings. shape of (N, dim) or (T, N, dim)
54
+ """
55
+ half = dim // 2
56
+ freqs = torch.exp(
57
+ -math.log(max_period)
58
+ * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
59
+ / half
60
+ )
61
+
62
+ if timesteps.dim() == 2:
63
+ timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N)
64
+
65
+ args = timesteps[..., None].float() * freqs[None]
66
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
67
+ if dim % 2:
68
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
69
+ return embedding
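A brief hedged shape check for timestep_embedding (import path taken from this file's location in the commit; the input tensors are arbitrary):

    # Shape check only; dim=192 matches the default time_embed_dim used below.
    import torch
    from zipvoice.models.modules.zipformer import timestep_embedding

    t = torch.tensor([0.1, 0.5])                    # one timestep per batch item, shape (N,)
    print(timestep_embedding(t, dim=192).shape)     # torch.Size([2, 192])

    t2 = torch.rand(2, 50)                          # per-frame timesteps, shape (N, T)
    print(timestep_embedding(t2, dim=192).shape)    # torch.Size([50, 2, 192])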
70
+
71
+
72
+ class TTSZipformer(nn.Module):
73
+ """
74
+ Args:
75
+
76
+ Note: all "int or Tuple[int]" arguments below will be treated as lists of the same
77
+ length as downsampling_factor if they are single ints or one-element tuples.
78
+ The length of downsampling_factor defines the number of stacks.
79
+
80
+ downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
81
+ Note: this is in addition to the downsampling factor of 2 that is applied in
82
+ the frontend (self.encoder_embed).
83
+ encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks,
84
+ one per encoder stack.
85
+ num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
86
+ query_head_dim (int or Tuple[int]): dimension of query and key per attention
87
+ head: per stack, if a tuple..
88
+ pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection
89
+ per attention head
90
+ value_head_dim (int or Tuple[int]): dimension of value in each attention head
91
+ num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
92
+ Must be at least 4.
93
+ feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
94
+ cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
95
+
96
+ pos_dim (int): the dimension of each positional-encoding vector prior to
97
+ projection, e.g. 128.
98
+
99
+ dropout (float): dropout rate
100
+ warmup_batches (float): number of batches to warm up over; this controls
101
+ dropout of encoder layers.
102
+ use_time_embed: (bool): if True, take time embedding as an additional input.
103
+ time_embed_dim: (int): the dimension of the time embedding.
104
+ use_guidance_scale_embed (bool): if True, take guidance scale embedding as
105
+ an additional input.
106
+ guidance_scale_embed_dim: (int): the dimension of the guidance scale embedding.
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ in_dim: int,
112
+ out_dim: int,
113
+ downsampling_factor: Union[int, Tuple[int]] = (2, 4),
114
+ num_encoder_layers: Union[int, Tuple[int]] = 4,
115
+ cnn_module_kernel: Union[int, Tuple[int]] = 31,
116
+ encoder_dim: int = 384,
117
+ query_head_dim: int = 24,
118
+ pos_head_dim: int = 4,
119
+ value_head_dim: int = 12,
120
+ num_heads: int = 8,
121
+ feedforward_dim: int = 1536,
122
+ pos_dim: int = 192,
123
+ dropout: FloatLike = None, # see code below for default
124
+ warmup_batches: float = 4000.0,
125
+ use_time_embed: bool = True,
126
+ time_embed_dim: int = 192,
127
+ use_guidance_scale_embed: bool = False,
128
+ guidance_scale_embed_dim: int = 192,
129
+ use_conv: bool = True,
130
+ ) -> None:
131
+ super(TTSZipformer, self).__init__()
132
+
133
+ if dropout is None:
134
+ dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
135
+ if isinstance(downsampling_factor, int):
136
+ downsampling_factor = (downsampling_factor,)
137
+
138
+ def _to_tuple(x):
139
+ """Converts a single int or a 1-tuple of an int to a tuple with the same
140
+ length as downsampling_factor"""
141
+ if isinstance(x, int):
142
+ x = (x,)
143
+ if len(x) == 1:
144
+ x = x * len(downsampling_factor)
145
+ else:
146
+ assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
147
+ return x
148
+
149
+ def _assert_downsampling_factor(factors):
150
+ """assert downsampling_factor follows u-net style"""
151
+ assert factors[0] == 1 and factors[-1] == 1
152
+
153
+ for i in range(1, len(factors) // 2 + 1):
154
+ assert factors[i] == factors[i - 1] * 2
155
+
156
+ for i in range(len(factors) // 2 + 1, len(factors)):
157
+ assert factors[i] * 2 == factors[i - 1]
158
+
159
+ _assert_downsampling_factor(downsampling_factor)
160
+ self.downsampling_factor = downsampling_factor # tuple
161
+ num_encoder_layers = _to_tuple(num_encoder_layers)
162
+ self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
163
+ self.encoder_dim = encoder_dim
164
+ self.num_encoder_layers = num_encoder_layers
165
+ self.query_head_dim = query_head_dim
166
+ self.value_head_dim = value_head_dim
167
+ self.num_heads = num_heads
168
+
169
+ self.use_time_embed = use_time_embed
170
+ self.use_guidance_scale_embed = use_guidance_scale_embed
171
+
172
+ self.time_embed_dim = time_embed_dim
173
+ if self.use_time_embed:
174
+ assert time_embed_dim != -1
175
+ else:
176
+ time_embed_dim = -1
177
+ self.guidance_scale_embed_dim = guidance_scale_embed_dim
178
+
179
+ self.in_proj = nn.Linear(in_dim, encoder_dim)
180
+ self.out_proj = nn.Linear(encoder_dim, out_dim)
181
+
182
+ # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
183
+ encoders = []
184
+
185
+ num_encoders = len(downsampling_factor)
186
+ for i in range(num_encoders):
187
+ encoder_layer = Zipformer2EncoderLayer(
188
+ embed_dim=encoder_dim,
189
+ pos_dim=pos_dim,
190
+ num_heads=num_heads,
191
+ query_head_dim=query_head_dim,
192
+ pos_head_dim=pos_head_dim,
193
+ value_head_dim=value_head_dim,
194
+ feedforward_dim=feedforward_dim,
195
+ use_conv=use_conv,
196
+ cnn_module_kernel=cnn_module_kernel[i],
197
+ dropout=dropout,
198
+ )
199
+
200
+ # For the segment of the warmup period, we let the Conv2dSubsampling
201
+ # layer learn something. Then we start to warm up the other encoders.
202
+ encoder = Zipformer2Encoder(
203
+ encoder_layer,
204
+ num_encoder_layers[i],
205
+ embed_dim=encoder_dim,
206
+ time_embed_dim=time_embed_dim,
207
+ pos_dim=pos_dim,
208
+ warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
209
+ warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
210
+ final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
211
+ )
212
+
213
+ if downsampling_factor[i] != 1:
214
+ encoder = DownsampledZipformer2Encoder(
215
+ encoder,
216
+ dim=encoder_dim,
217
+ downsample=downsampling_factor[i],
218
+ )
219
+
220
+ encoders.append(encoder)
221
+
222
+ self.encoders = nn.ModuleList(encoders)
223
+ if self.use_time_embed:
224
+ self.time_embed = nn.Sequential(
225
+ nn.Linear(time_embed_dim, time_embed_dim * 2),
226
+ SwooshR(),
227
+ nn.Linear(time_embed_dim * 2, time_embed_dim),
228
+ )
229
+ else:
230
+ self.time_embed = None
231
+
232
+ if self.use_guidance_scale_embed:
233
+ self.guidance_scale_embed = ScaledLinear(
234
+ guidance_scale_embed_dim,
235
+ time_embed_dim,
236
+ bias=False,
237
+ initial_scale=0.1,
238
+ )
239
+ else:
240
+ self.guidance_scale_embed = None
241
+
242
+ def forward(
243
+ self,
244
+ x: Tensor,
245
+ t: Optional[Tensor] = None,
246
+ padding_mask: Optional[Tensor] = None,
247
+ guidance_scale: Optional[Tensor] = None,
248
+ ) -> Tensor:
249
+ """
250
+ Args:
251
+ x:
252
+ The input tensor. Its shape is (batch_size, seq_len, feature_dim).
253
+ t:
254
+ A t tensor of shape (batch_size,) or (batch_size, seq_len)
255
+ padding_mask:
256
+ The mask for padding, of shape (batch_size, seq_len); True means
257
+ masked position. May be None.
258
+ guidance_scale:
259
+ The guidance scale in classifier-free guidance of distillation model.
260
+ Returns:
261
+ Return the output embeddings. its shape is
262
+ (batch_size, output_seq_len, encoder_dim)
263
+ """
264
+ x = x.permute(1, 0, 2)
265
+ x = self.in_proj(x)
266
+
267
+ if t is not None:
268
+ assert t.dim() == 1 or t.dim() == 2, t.shape
269
+ time_emb = timestep_embedding(t, self.time_embed_dim)
270
+ if guidance_scale is not None:
271
+ assert (
272
+ guidance_scale.dim() == 1 or guidance_scale.dim() == 2
273
+ ), guidance_scale.shape
274
+ guidance_scale_emb = self.guidance_scale_embed(
275
+ timestep_embedding(guidance_scale, self.guidance_scale_embed_dim)
276
+ )
277
+ time_emb = time_emb + guidance_scale_emb
278
+ time_emb = self.time_embed(time_emb)
279
+ else:
280
+ time_emb = None
281
+
282
+ attn_mask = None
283
+
284
+ for i, module in enumerate(self.encoders):
285
+ x = module(
286
+ x,
287
+ time_emb=time_emb,
288
+ src_key_padding_mask=padding_mask,
289
+ attn_mask=attn_mask,
290
+ )
291
+ x = self.out_proj(x)
292
+ x = x.permute(1, 0, 2)
293
+ return x
294
+
295
+
296
+ def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
297
+ return ScheduledFloat((0.0, x), (20000.0, ratio * x), default=x)
298
+
299
+
300
+ class Zipformer2EncoderLayer(nn.Module):
301
+ """
302
+ Args:
303
+ embed_dim: the number of expected features in the input (required).
304
+ nhead: the number of heads in the multiheadattention models (required).
305
+ feedforward_dim: the dimension of the feedforward network model (required).
306
+ dropout: the dropout value (default=0.1).
307
+ cnn_module_kernel (int): Kernel size of convolution module (default=31).
308
+
309
+ Examples::
310
+ >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
311
+ >>> src = torch.rand(10, 32, 512)
312
+ >>> pos_emb = torch.rand(32, 19, 512)
313
+ >>> out = encoder_layer(src, pos_emb)
314
+ """
315
+
316
+ def __init__(
317
+ self,
318
+ embed_dim: int,
319
+ pos_dim: int,
320
+ num_heads: int,
321
+ query_head_dim: int,
322
+ pos_head_dim: int,
323
+ value_head_dim: int,
324
+ feedforward_dim: int,
325
+ dropout: FloatLike = 0.1,
326
+ cnn_module_kernel: int = 31,
327
+ use_conv: bool = True,
328
+ attention_skip_rate: FloatLike = ScheduledFloat(
329
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
330
+ ),
331
+ conv_skip_rate: FloatLike = ScheduledFloat(
332
+ (0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0
333
+ ),
334
+ const_attention_rate: FloatLike = ScheduledFloat(
335
+ (0.0, 0.25), (4000.0, 0.025), default=0
336
+ ),
337
+ ff2_skip_rate: FloatLike = ScheduledFloat(
338
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
339
+ ),
340
+ ff3_skip_rate: FloatLike = ScheduledFloat(
341
+ (0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)
342
+ ),
343
+ bypass_skip_rate: FloatLike = ScheduledFloat(
344
+ (0.0, 0.5), (4000.0, 0.02), default=0
345
+ ),
346
+ ) -> None:
347
+ super(Zipformer2EncoderLayer, self).__init__()
348
+ self.embed_dim = embed_dim
349
+
350
+ # self.bypass implements layer skipping as well as bypass.
351
+ self.bypass = BypassModule(
352
+ embed_dim, skip_rate=bypass_skip_rate, straight_through_rate=0
353
+ )
354
+ # bypass_mid is bypass used in the middle of the layer.
355
+ self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)
356
+
357
+ # skip probability for dynamic modules (meaning: anything but feedforward).
358
+ self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
359
+ # an additional skip probability that applies to ConvModule to stop it from
360
+ # contributing too much early on.
361
+ self.conv_skip_rate = copy.deepcopy(conv_skip_rate)
362
+
363
+ # ff2_skip_rate is to prevent the ff2 module from having output that's too big
364
+ # compared to its residual.
365
+ self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate)
366
+ self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate)
367
+
368
+ self.const_attention_rate = copy.deepcopy(const_attention_rate)
369
+
370
+ self.self_attn_weights = RelPositionMultiheadAttentionWeights(
371
+ embed_dim,
372
+ pos_dim=pos_dim,
373
+ num_heads=num_heads,
374
+ query_head_dim=query_head_dim,
375
+ pos_head_dim=pos_head_dim,
376
+ dropout=0.0,
377
+ )
378
+
379
+ self.self_attn1 = SelfAttention(embed_dim, num_heads, value_head_dim)
380
+
381
+ self.self_attn2 = SelfAttention(embed_dim, num_heads, value_head_dim)
382
+
383
+ self.feed_forward1 = FeedforwardModule(
384
+ embed_dim, (feedforward_dim * 3) // 4, dropout
385
+ )
386
+
387
+ self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)
388
+
389
+ self.feed_forward3 = FeedforwardModule(
390
+ embed_dim, (feedforward_dim * 5) // 4, dropout
391
+ )
392
+
393
+ self.nonlin_attention = NonlinAttention(
394
+ embed_dim, hidden_channels=3 * embed_dim // 4
395
+ )
396
+
397
+ self.use_conv = use_conv
398
+
399
+ if self.use_conv:
400
+ self.conv_module1 = ConvolutionModule(embed_dim, cnn_module_kernel)
401
+
402
+ self.conv_module2 = ConvolutionModule(embed_dim, cnn_module_kernel)
403
+
404
+ self.norm = BiasNorm(embed_dim)
405
+
406
+ self.balancer1 = Balancer(
407
+ embed_dim,
408
+ channel_dim=-1,
409
+ min_positive=0.45,
410
+ max_positive=0.55,
411
+ min_abs=0.2,
412
+ max_abs=4.0,
413
+ )
414
+
415
+ # balancer for output of NonlinAttentionModule
416
+ self.balancer_na = Balancer(
417
+ embed_dim,
418
+ channel_dim=-1,
419
+ min_positive=0.3,
420
+ max_positive=0.7,
421
+ min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
422
+ prob=0.05, # out of concern for memory usage
423
+ )
424
+
425
+ # balancer for output of feedforward2, prevent it from staying too
426
+ # small. give this a very small probability, even at the start of
427
+ # training, it's to fix a rare problem and it's OK to fix it slowly.
428
+ self.balancer_ff2 = Balancer(
429
+ embed_dim,
430
+ channel_dim=-1,
431
+ min_positive=0.3,
432
+ max_positive=0.7,
433
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
434
+ max_abs=2.0,
435
+ prob=0.05,
436
+ )
437
+
438
+ self.balancer_ff3 = Balancer(
439
+ embed_dim,
440
+ channel_dim=-1,
441
+ min_positive=0.3,
442
+ max_positive=0.7,
443
+ min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0),
444
+ max_abs=4.0,
445
+ prob=0.05,
446
+ )
447
+
448
+ self.whiten = Whiten(
449
+ num_groups=1,
450
+ whitening_limit=_whitening_schedule(4.0, ratio=3.0),
451
+ prob=(0.025, 0.25),
452
+ grad_scale=0.01,
453
+ )
454
+
455
+ self.balancer2 = Balancer(
456
+ embed_dim,
457
+ channel_dim=-1,
458
+ min_positive=0.45,
459
+ max_positive=0.55,
460
+ min_abs=0.1,
461
+ max_abs=4.0,
462
+ )
463
+
464
+ def get_sequence_dropout_mask(
465
+ self, x: Tensor, dropout_rate: float
466
+ ) -> Optional[Tensor]:
467
+ if (
468
+ dropout_rate == 0.0
469
+ or not self.training
470
+ or torch.jit.is_scripting()
471
+ or torch.jit.is_tracing()
472
+ ):
473
+ return None
474
+ batch_size = x.shape[1]
475
+ mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
476
+ return mask
477
+
478
+ def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
479
+ """
480
+ Apply sequence-level dropout to x.
481
+ x shape: (seq_len, batch_size, embed_dim)
482
+ """
483
+ dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate)
484
+ if dropout_mask is None:
485
+ return x
486
+ else:
487
+ return x * dropout_mask
488
+
489
+ def forward(
490
+ self,
491
+ src: Tensor,
492
+ pos_emb: Tensor,
493
+ time_emb: Optional[Tensor] = None,
494
+ attn_mask: Optional[Tensor] = None,
495
+ src_key_padding_mask: Optional[Tensor] = None,
496
+ ) -> Tensor:
497
+ """
498
+ Pass the input through the encoder layer.
499
+ Args:
500
+ src: the sequence to the encoder (required):
501
+ shape (seq_len, batch_size, embedding_dim).
502
+ pos_emb: (1, 2*seq_len-1, pos_emb_dim) or
503
+ (batch_size, 2*seq_len-1, pos_emb_dim)
504
+ time_emb: the embedding representing the current timestep
505
+ shape (batch_size, embedding_dim) or (seq_len, batch_size, embedding_dim).
506
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len)
507
+ or (seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len)
508
+ or (tgt_seq_len, src_seq_len). True means masked position. May be None.
509
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len);
510
+ True means masked position. May be None.
511
+
512
+ Returns:
513
+ A tensor which has the same shape as src
514
+ """
515
+ src_orig = src
516
+
517
+ # dropout rate for non-feedforward submodules
518
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
519
+ attention_skip_rate = 0.0
520
+ else:
521
+ attention_skip_rate = (
522
+ float(self.attention_skip_rate) if self.training else 0.0
523
+ )
524
+
525
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
526
+ attn_weights = self.self_attn_weights(
527
+ src,
528
+ pos_emb=pos_emb,
529
+ attn_mask=attn_mask,
530
+ key_padding_mask=src_key_padding_mask,
531
+ )
532
+ if time_emb is not None:
533
+
534
+ src = src + time_emb
535
+
536
+ src = src + self.feed_forward1(src)
537
+
538
+ self_attn_dropout_mask = self.get_sequence_dropout_mask(
539
+ src, attention_skip_rate
540
+ )
541
+
542
+ selected_attn_weights = attn_weights[0:1]
543
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
544
+ pass
545
+ elif self.training and random.random() < float(self.const_attention_rate):
546
+ # Make attention weights constant. The intention is to
547
+ # encourage these modules to do something similar to an
548
+ # averaging-over-time operation.
549
+ # only need the mask, can just use the 1st one and expand later
550
+ selected_attn_weights = selected_attn_weights[0:1]
551
+ selected_attn_weights = (selected_attn_weights > 0.0).to(
552
+ selected_attn_weights.dtype
553
+ )
554
+ selected_attn_weights = selected_attn_weights * (
555
+ 1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
556
+ )
557
+
558
+ na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
559
+
560
+ src = src + (
561
+ na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
562
+ )
563
+
564
+ self_attn = self.self_attn1(src, attn_weights)
565
+
566
+ src = src + (
567
+ self_attn
568
+ if self_attn_dropout_mask is None
569
+ else self_attn * self_attn_dropout_mask
570
+ )
571
+
572
+ if self.use_conv:
573
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
574
+ conv_skip_rate = 0.0
575
+ else:
576
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
577
+
578
+ if time_emb is not None:
579
+ src = src + time_emb
580
+
581
+ src = src + self.sequence_dropout(
582
+ self.conv_module1(
583
+ src,
584
+ src_key_padding_mask=src_key_padding_mask,
585
+ ),
586
+ conv_skip_rate,
587
+ )
588
+
589
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
590
+ ff2_skip_rate = 0.0
591
+ else:
592
+ ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
593
+ src = src + self.sequence_dropout(
594
+ self.balancer_ff2(self.feed_forward2(src)), ff2_skip_rate
595
+ )
596
+
597
+ # bypass in the middle of the layer.
598
+ src = self.bypass_mid(src_orig, src)
599
+
600
+ self_attn = self.self_attn2(src, attn_weights)
601
+
602
+ src = src + (
603
+ self_attn
604
+ if self_attn_dropout_mask is None
605
+ else self_attn * self_attn_dropout_mask
606
+ )
607
+
608
+ if self.use_conv:
609
+
610
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
611
+ conv_skip_rate = 0.0
612
+ else:
613
+ conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
614
+
615
+ if time_emb is not None:
616
+ src = src + time_emb
617
+
618
+ src = src + self.sequence_dropout(
619
+ self.conv_module2(
620
+ src,
621
+ src_key_padding_mask=src_key_padding_mask,
622
+ ),
623
+ conv_skip_rate,
624
+ )
625
+
626
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
627
+ ff3_skip_rate = 0.0
628
+ else:
629
+ ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
630
+ src = src + self.sequence_dropout(
631
+ self.balancer_ff3(self.feed_forward3(src)), ff3_skip_rate
632
+ )
633
+
634
+ src = self.balancer1(src)
635
+ src = self.norm(src)
636
+
637
+ src = self.bypass(src_orig, src)
638
+
639
+ src = self.balancer2(src)
640
+ src = self.whiten(src)
641
+
642
+ return src
643
+
644
+
645
+ class Zipformer2Encoder(nn.Module):
646
+ r"""Zipformer2Encoder is a stack of N encoder layers
647
+
648
+ Args:
649
+ encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
650
+ num_layers: the number of sub-encoder-layers in the encoder (required).
651
+ pos_dim: the dimension for the relative positional encoding
652
+
653
+ Examples::
654
+ >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
655
+ >>> zipformer_encoder = Zipformer2Encoder(
+ ... encoder_layer, num_layers=6, embed_dim=512, time_embed_dim=-1,
+ ... pos_dim=192, warmup_begin=0.0, warmup_end=1000.0,
+ ... )
656
+ >>> src = torch.rand(10, 32, 512)
657
+ >>> out = zipformer_encoder(src)
658
+ """
659
+
660
+ def __init__(
661
+ self,
662
+ encoder_layer: nn.Module,
663
+ num_layers: int,
664
+ embed_dim: int,
665
+ time_embed_dim: int,
666
+ pos_dim: int,
667
+ warmup_begin: float,
668
+ warmup_end: float,
669
+ initial_layerdrop_rate: float = 0.5,
670
+ final_layerdrop_rate: float = 0.05,
671
+ ) -> None:
672
+ super().__init__()
673
+ self.encoder_pos = CompactRelPositionalEncoding(
674
+ pos_dim, dropout_rate=0.15, length_factor=1.0
675
+ )
676
+ if time_embed_dim != -1:
677
+ self.time_emb = nn.Sequential(
678
+ SwooshR(),
679
+ nn.Linear(time_embed_dim, embed_dim),
680
+ )
681
+ else:
682
+ self.time_emb = None
683
+
684
+ self.layers = nn.ModuleList(
685
+ [copy.deepcopy(encoder_layer) for i in range(num_layers)]
686
+ )
687
+ self.num_layers = num_layers
688
+
689
+ assert 0 <= warmup_begin <= warmup_end
690
+
691
+ delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
692
+ cur_begin = warmup_begin # interpreted as a training batch index
693
+ for i in range(num_layers):
694
+ cur_end = cur_begin + delta
695
+ self.layers[i].bypass.skip_rate = ScheduledFloat(
696
+ (cur_begin, initial_layerdrop_rate),
697
+ (cur_end, final_layerdrop_rate),
698
+ default=0.0,
699
+ )
700
+ cur_begin = cur_end
701
+
702
+ def forward(
703
+ self,
704
+ src: Tensor,
705
+ time_emb: Optional[Tensor] = None,
706
+ attn_mask: Optional[Tensor] = None,
707
+ src_key_padding_mask: Optional[Tensor] = None,
708
+ ) -> Tensor:
709
+ r"""Pass the input through the encoder layers in turn.
710
+
711
+ Args:
712
+ src: the sequence to the encoder (required):
713
+ shape (seq_len, batch_size, embedding_dim).
714
+ time_emb: the embedding representing the current timestep:
715
+ shape (batch_size, embedding_dim)
716
+ or (seq_len, batch_size, embedding_dim) .
717
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len)
718
+ or (seq_len, seq_len), interpreted as
719
+ (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
720
+ True means masked position. May be None.
721
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len);
722
+ True means masked position. May be None.
723
+
724
+ Returns: a Tensor with the same shape as src.
725
+ """
726
+ pos_emb = self.encoder_pos(src)
727
+ if self.time_emb is not None:
728
+ assert time_emb is not None
729
+ time_emb = self.time_emb(time_emb)
730
+ else:
731
+ assert time_emb is None
732
+
733
+ output = src
734
+
735
+ for i, mod in enumerate(self.layers):
736
+ output = mod(
737
+ output,
738
+ pos_emb,
739
+ time_emb=time_emb,
740
+ attn_mask=attn_mask,
741
+ src_key_padding_mask=src_key_padding_mask,
742
+ )
743
+
744
+ return output
745
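The loop in __init__ above splits the [warmup_begin, warmup_end] batch range evenly across the layers, so each layer's bypass skip_rate decays from initial_layerdrop_rate to final_layerdrop_rate over its own window, one layer after another. A standalone sketch of that schedule arithmetic with hypothetical numbers (plain Python, not a call into the module):

    num_layers = 4
    warmup_begin, warmup_end = 1000.0, 3000.0
    initial_layerdrop_rate, final_layerdrop_rate = 0.5, 0.05
    delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
    cur_begin = warmup_begin
    for i in range(num_layers):
        cur_end = cur_begin + delta
        # layer i: skip_rate decays from initial_layerdrop_rate at batch cur_begin
        # down to final_layerdrop_rate at batch cur_end
        print(i, cur_begin, cur_end)
        cur_begin = cur_end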
+
746
+
747
+ class BypassModule(nn.Module):
748
+ """
749
+ An nn.Module that implements a learnable bypass scale, and also randomized
750
+ per-sequence layer-skipping. The bypass is limited during early stages of training
751
+ to be close to "straight-through", i.e. to not do the bypass operation much
752
+ initially, in order to force all the modules to learn something.
753
+ """
754
+
755
+ def __init__(
756
+ self,
757
+ embed_dim: int,
758
+ skip_rate: FloatLike = 0.0,
759
+ straight_through_rate: FloatLike = 0.0,
760
+ scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
761
+ scale_max: FloatLike = 1.0,
762
+ ):
763
+ super().__init__()
764
+ self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
765
+ self.skip_rate = copy.deepcopy(skip_rate)
766
+ self.straight_through_rate = copy.deepcopy(straight_through_rate)
767
+ self.scale_min = copy.deepcopy(scale_min)
768
+ self.scale_max = copy.deepcopy(scale_max)
769
+
770
+ def _get_bypass_scale(self, batch_size: int):
771
+ # returns bypass-scale of shape (num_channels,),
772
+ # or (batch_size, num_channels,). This is actually the
773
+ # scale on the non-residual term, so 0 corresponds to bypassing
774
+ # this module.
775
+ if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
776
+ return self.bypass_scale
777
+ else:
778
+ ans = limit_param_value(
779
+ self.bypass_scale,
780
+ min=float(self.scale_min),
781
+ max=float(self.scale_max),
782
+ )
783
+ skip_rate = float(self.skip_rate)
784
+ if skip_rate != 0.0:
785
+ mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
786
+ ans = ans * mask
787
+ # now ans is of shape (batch_size, num_channels), and is zero for
788
+ # sequences on which we have randomly chosen to do layer-skipping.
789
+ straight_through_rate = float(self.straight_through_rate)
790
+ if straight_through_rate != 0.0:
791
+ mask = (
792
+ torch.rand((batch_size, 1), device=ans.device)
793
+ < straight_through_rate
794
+ )
795
+ ans = torch.maximum(ans, mask.to(ans.dtype))
796
+ return ans
797
+
798
+ def forward(self, src_orig: Tensor, src: Tensor):
799
+ """
800
+ Args: src_orig and src are both of shape (seq_len, batch_size, num_channels)
801
+ Returns: something with the same shape as src and src_orig
802
+ """
803
+ bypass_scale = self._get_bypass_scale(src.shape[1])
804
+ return src_orig + (src - src_orig) * bypass_scale
805
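The forward above is a per-channel interpolation between the block input and its output: scale 0 bypasses the module entirely, scale 1 keeps its full output, and training-time masks can zero the scale for randomly chosen sequences. A minimal sketch of the interpolation itself, with hypothetical shapes and a fixed scale standing in for bypass_scale:

    import torch

    seq_len, batch, channels = 5, 2, 8
    src_orig = torch.randn(seq_len, batch, channels)   # block input
    src = torch.randn(seq_len, batch, channels)        # block output
    scale = torch.full((channels,), 0.5)               # stand-in for bypass_scale
    out = src_orig + (src - src_orig) * scale
    assert out.shape == src.shape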
+
806
+
807
+ class DownsampledZipformer2Encoder(nn.Module):
808
+ r"""
809
+ DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame
810
+ rate, after convolutional downsampling, and then upsampled again at the output, and
811
+ combined with the original input, so that the output has the same shape as the input.
812
+ """
813
+
814
+ def __init__(self, encoder: nn.Module, dim: int, downsample: int):
815
+ super(DownsampledZipformer2Encoder, self).__init__()
816
+ self.downsample_factor = downsample
817
+ self.downsample = SimpleDownsample(downsample)
818
+ self.num_layers = encoder.num_layers
819
+ self.encoder = encoder
820
+ self.upsample = SimpleUpsample(downsample)
821
+ self.out_combiner = BypassModule(dim, straight_through_rate=0)
822
+
823
+ def forward(
824
+ self,
825
+ src: Tensor,
826
+ time_emb: Optional[Tensor] = None,
827
+ attn_mask: Optional[Tensor] = None,
828
+ src_key_padding_mask: Optional[Tensor] = None,
829
+ ) -> Tensor:
830
+ r"""Downsample, go through encoder, upsample.
831
+
832
+ Args:
833
+ src: the sequence to the encoder (required):
834
+ shape (seq_len, batch_size, embedding_dim).
835
+ time_emb: the embedding representing the current timestep:
836
+ shape (batch_size, embedding_dim)
837
+ or (seq_len, batch_size, embedding_dim) .
838
+ feature_mask: something that broadcasts with src, that we'll multiply `src`
839
+ by at every layer: if a Tensor, likely of shape
840
+ (seq_len, batch_size, embedding_dim)
841
+ attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len)
842
+ or (seq_len, seq_len), interpreted as
843
+ (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
844
+ True means masked position. May be None.
845
+ src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len);
846
+ True means masked position. May be None.
847
+
848
+ Returns: a Tensor with the same shape as src.
849
+ """
850
+ src_orig = src
851
+ src = self.downsample(src)
852
+ ds = self.downsample_factor
853
+ if time_emb is not None and time_emb.dim() == 3:
854
+ time_emb = time_emb[::ds]
855
+ if attn_mask is not None:
856
+ attn_mask = attn_mask[::ds, ::ds]
857
+ if src_key_padding_mask is not None:
858
+ src_key_padding_mask = src_key_padding_mask[..., ::ds]
859
+
860
+ src = self.encoder(
861
+ src,
862
+ time_emb=time_emb,
863
+ attn_mask=attn_mask,
864
+ src_key_padding_mask=src_key_padding_mask,
865
+ )
866
+ src = self.upsample(src)
867
+ # remove any extra frames that are not a multiple of downsample_factor
868
+ src = src[: src_orig.shape[0]]
869
+
870
+ return self.out_combiner(src_orig, src)
871
+
872
+
873
+ class SimpleDownsample(torch.nn.Module):
874
+ """
875
+ Does downsampling with attention, by weighted sum.
876
+ """
877
+
878
+ def __init__(self, downsample: int):
879
+ super(SimpleDownsample, self).__init__()
880
+
881
+ self.bias = nn.Parameter(torch.zeros(downsample))
882
+
883
+ self.name = None # will be set from training code
884
+
885
+ self.downsample = downsample
886
+
887
+ def forward(self, src: Tensor) -> Tensor:
888
+ """
889
+ x: (seq_len, batch_size, in_channels)
890
+ Returns a tensor of shape
891
+ ( (seq_len+downsample-1)//downsample, batch_size, channels)
892
+ """
893
+ (seq_len, batch_size, in_channels) = src.shape
894
+ ds = self.downsample
895
+ d_seq_len = (seq_len + ds - 1) // ds
896
+
897
+ # Pad to an exact multiple of self.downsample
898
+ # right-pad src, repeating the last element.
899
+ pad = d_seq_len * ds - seq_len
900
+ src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
901
+ src = torch.cat((src, src_extra), dim=0)
902
+ assert src.shape[0] == d_seq_len * ds
903
+
904
+ src = src.reshape(d_seq_len, ds, batch_size, in_channels)
905
+
906
+ weights = self.bias.softmax(dim=0)
907
+ # weights: (downsample, 1, 1)
908
+ weights = weights.unsqueeze(-1).unsqueeze(-1)
909
+
910
+ # ans1 is the first `in_channels` channels of the output
911
+ ans = (src * weights).sum(dim=1)
912
+
913
+ return ans
914
+
915
+
916
+ class SimpleUpsample(torch.nn.Module):
917
+ """
918
+ A very simple form of upsampling that just repeats the input.
919
+ """
920
+
921
+ def __init__(self, upsample: int):
922
+ super(SimpleUpsample, self).__init__()
923
+ self.upsample = upsample
924
+
925
+ def forward(self, src: Tensor) -> Tensor:
926
+ """
927
+ x: (seq_len, batch_size, num_channels)
928
+ Returns a tensor of shape
929
+ ( (seq_len*upsample), batch_size, num_channels)
930
+ """
931
+ upsample = self.upsample
932
+ (seq_len, batch_size, num_channels) = src.shape
933
+ src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
934
+ src = src.reshape(seq_len * upsample, batch_size, num_channels)
935
+ return src
936
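Together, SimpleDownsample and SimpleUpsample only change the time axis, and upsampling a downsampled sequence can overshoot the original length, which is why DownsampledZipformer2Encoder trims the result to src_orig.shape[0]. A small shape check, assuming the two classes defined above (hypothetical sizes):

    import torch

    down = SimpleDownsample(downsample=2)
    up = SimpleUpsample(upsample=2)
    src = torch.randn(7, 3, 16)          # (seq_len, batch_size, channels)
    ds = down(src)                       # ceil(7 / 2) == 4 frames
    assert ds.shape == (4, 3, 16)
    us = up(ds)                          # 4 * 2 == 8 frames, one more than the input
    assert us.shape == (8, 3, 16)
    us = us[: src.shape[0]]              # trim back to 7, as the encoder wrapper does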
+
937
+
938
+ class CompactRelPositionalEncoding(torch.nn.Module):
939
+ """
940
+ Relative positional encoding module. This version is "compact" meaning it is able
941
+ to encode the important information about the relative position in a relatively
942
+ small number of dimensions. The goal is to make it so that small differences between
943
+ large relative offsets (e.g. 1000 vs. 1001) make very little difference to the
944
+ embedding. Such differences were potentially important when encoding absolute
945
+ position, but not important when encoding relative position because there is now no
946
+ need to compare two large offsets with each other.
947
+
948
+ Our embedding works by projecting the interval [-infinity,infinity] to a finite
949
+ interval using the atan() function, before doing the Fourier transform of that fixed
950
+ interval. The atan() function would compress the "long tails" too small, making it
951
+ hard to distinguish between different magnitudes of large offsets, so we use a
952
+ logarithmic function to compress large offsets to a smaller range before applying
953
+ atan(). Scalings are chosen in such a way that the embedding can clearly distinguish
954
+ individual offsets as long as they are quite close to the origin, e.g. abs(offset)
955
+ <= about sqrt(embedding_dim)
956
+
957
+
958
+ Args:
959
+ embed_dim: Embedding dimension.
960
+ dropout_rate: Dropout rate.
961
+ max_len: Maximum input length: just a heuristic for initialization.
962
+ length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
963
+ less weight to small differences of offset near the origin.
964
+ """
965
+
966
+ def __init__(
967
+ self,
968
+ embed_dim: int,
969
+ dropout_rate: FloatLike,
970
+ max_len: int = 1000,
971
+ length_factor: float = 1.0,
972
+ ) -> None:
973
+ """Construct a CompactRelPositionalEncoding object."""
974
+ super(CompactRelPositionalEncoding, self).__init__()
975
+ self.embed_dim = embed_dim
976
+ assert embed_dim % 2 == 0, embed_dim
977
+ self.dropout = Dropout2(dropout_rate)
978
+ self.pe = None
979
+ assert length_factor >= 1.0, length_factor
980
+ self.length_factor = length_factor
981
+ self.extend_pe(torch.tensor(0.0).expand(max_len))
982
+
983
+ def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None:
984
+ """Reset the positional encodings."""
985
+ T = x.size(0) + left_context_len
986
+
987
+ if self.pe is not None:
988
+ # self.pe contains both positive and negative parts
989
+ # the length of self.pe is 2 * input_len - 1
990
+ if self.pe.size(0) >= T * 2 - 1:
991
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
992
+ return
993
+
994
+ # if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]
995
+ x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
996
+
997
+ freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
998
+
999
+ # `compression_length` this is arbitrary/heuristic, if it is larger we have more
1000
+ # resolution for small time offsets but less resolution for large time offsets.
1001
+ compression_length = self.embed_dim**0.5
1002
+ # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity
1003
+ # to infinity; but it does so more slowly than T for large absolute values of T.
1004
+ # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which is
1005
+ # important.
1006
+ x_compressed = (
1007
+ compression_length
1008
+ * x.sign()
1009
+ * ((x.abs() + compression_length).log() - math.log(compression_length))
1010
+ )
1011
+
1012
+ # if self.length_factor == 1.0, then length_scale is chosen so that the
1013
+ # FFT can exactly separate points close to the origin (T == 0). So this
1014
+ # part of the formulation is not really heuristic.
1015
+ # But empirically, for ASR at least, length_factor > 1.0 seems to work better.
1016
+ length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi)
1017
+
1018
+ # note for machine implementations: if atan is not available, we can use:
1019
+ # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
1020
+ # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 ,
1021
+ # atan(x))
1022
+ x_atan = (x_compressed / length_scale).atan() # results between -pi and pi
1023
+
1024
+ cosines = (x_atan * freqs).cos()
1025
+ sines = (x_atan * freqs).sin()
1026
+
1027
+ pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device)
1028
+ pe[:, 0::2] = cosines
1029
+ pe[:, 1::2] = sines
1030
+ pe[:, -1] = 1.0 # for bias.
1031
+
1032
+ self.pe = pe.to(dtype=x.dtype)
1033
+
1034
+ def forward(self, x: Tensor, left_context_len: int = 0) -> Tensor:
1035
+ """Create positional encoding.
1036
+
1037
+ Args:
1038
+ x (Tensor): Input tensor (time, batch, `*`).
1039
+ left_context_len: (int): Length of cached left context.
1040
+
1041
+ Returns:
1042
+ positional embedding, of shape (batch, left_context_len + 2*time-1, `*`).
1043
+ """
1044
+ self.extend_pe(x, left_context_len)
1045
+ x_size_left = x.size(0) + left_context_len
1046
+ # length of positive side: x.size(0) + left_context_len
1047
+ # length of negative side: x.size(0)
1048
+ pos_emb = self.pe[
1049
+ self.pe.size(0) // 2
1050
+ - x_size_left
1051
+ + 1 : self.pe.size(0) // 2 # noqa E203
1052
+ + x.size(0),
1053
+ :,
1054
+ ]
1055
+ pos_emb = pos_emb.unsqueeze(0)
1056
+ return self.dropout(pos_emb)
1057
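A standalone sketch of the offset compression described above: offsets are squashed logarithmically, mapped into a bounded interval with atan(), and only then turned into sinusoids, so offsets 1000 and 1001 end up almost indistinguishable while 0 and 1 stay well separated. This mirrors the math in extend_pe() with hypothetical sizes; it is not the module itself (the real module also reserves the last channel as a bias):

    import math
    import torch

    embed_dim = 8                                  # must be even, as asserted above
    offsets = torch.arange(-1000, 1001).float().unsqueeze(1)
    compression_length = embed_dim ** 0.5
    compressed = (
        compression_length
        * offsets.sign()
        * ((offsets.abs() + compression_length).log() - math.log(compression_length))
    )
    length_scale = 1.0 * embed_dim / (2.0 * math.pi)   # length_factor == 1.0
    x_atan = (compressed / length_scale).atan()        # bounded in (-pi/2, pi/2)
    freqs = 1 + torch.arange(embed_dim // 2)
    pe = torch.zeros(offsets.shape[0], embed_dim)
    pe[:, 0::2] = (x_atan * freqs).cos()
    pe[:, 1::2] = (x_atan * freqs).sin()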
+
1058
+
1059
+ class RelPositionMultiheadAttentionWeights(nn.Module):
1060
+ r"""Module that computes multi-head attention weights with relative position
1061
+ encoding. Various other modules consume the resulting attention weights:
1062
+ see, for example, the SimpleAttention module which allows you to compute
1063
+ conventional attention.
1064
+
1065
+ This is quite heavily modified from: "Transformer-XL: Attentive Language
1066
+ Models Beyond a Fixed-Length Context",
1067
+ we have to write up the differences.
1068
+
1069
+
1070
+ Args:
1071
+ embed_dim: number of channels at the input to this module, e.g. 256
1072
+ pos_dim: dimension of the positional encoding vectors, e.g. 128.
1073
+ num_heads: number of heads to compute weights for, e.g. 8
1074
+ query_head_dim: dimension of the query (and key), per head. e.g. 24.
1075
+ pos_head_dim: dimension of the projected positional encoding per head, e.g. 4.
1076
+ dropout: dropout probability for attn_output_weights. Default: 0.0.
1077
+ pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on
1078
+ any given call to forward(), in training time.
1079
+ """
1080
+
1081
+ def __init__(
1082
+ self,
1083
+ embed_dim: int,
1084
+ pos_dim: int,
1085
+ num_heads: int,
1086
+ query_head_dim: int,
1087
+ pos_head_dim: int,
1088
+ dropout: float = 0.0,
1089
+ pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
1090
+ ) -> None:
1091
+ super().__init__()
1092
+ self.embed_dim = embed_dim
1093
+ self.num_heads = num_heads
1094
+ self.query_head_dim = query_head_dim
1095
+ self.pos_head_dim = pos_head_dim
1096
+ self.dropout = dropout
1097
+ self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
1098
+ self.name = None # will be overwritten in training code; for diagnostics.
1099
+
1100
+ key_head_dim = query_head_dim
1101
+ in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
1102
+
1103
+ # the initial_scale is supposed to take over the "scaling" factor of
1104
+ # head_dim ** -0.5 that has been used in previous forms of attention,
1105
+ # dividing it between the query and key. Note: this module is intended
1106
+ # to be used with the ScaledAdam optimizer; with most other optimizers,
1107
+ # it would be necessary to apply the scaling factor in the forward function.
1108
+ self.in_proj = ScaledLinear(
1109
+ embed_dim,
1110
+ in_proj_dim,
1111
+ bias=True,
1112
+ initial_scale=query_head_dim**-0.25,
1113
+ )
1114
+
1115
+ self.whiten_keys = Whiten(
1116
+ num_groups=num_heads,
1117
+ whitening_limit=_whitening_schedule(3.0),
1118
+ prob=(0.025, 0.25),
1119
+ grad_scale=0.025,
1120
+ )
1121
+
1122
+ # add a balancer for the keys that runs with very small probability, and
1123
+ # tries to enforce that all dimensions have mean around zero. The
1124
+ # weights produced by this module are invariant to adding a constant to
1125
+ # the keys, so the derivative of the bias is mathematically zero; but
1126
+ # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero
1127
+ # bias because the small numerical roundoff tends to have a non-random
1128
+ # sign. This module is intended to prevent that. Use a very small
1129
+ # probability; that should be sufficient to fix the problem.
1130
+ self.balance_keys = Balancer(
1131
+ key_head_dim * num_heads,
1132
+ channel_dim=-1,
1133
+ min_positive=0.4,
1134
+ max_positive=0.6,
1135
+ min_abs=0.0,
1136
+ max_abs=100.0,
1137
+ prob=0.025,
1138
+ )
1139
+
1140
+ # linear transformation for positional encoding.
1141
+ self.linear_pos = ScaledLinear(
1142
+ pos_dim, num_heads * pos_head_dim, bias=False, initial_scale=0.05
1143
+ )
1144
+
1145
+ # the following are for diagnostics only, see --print-diagnostics option
1146
+ self.copy_pos_query = Identity()
1147
+ self.copy_query = Identity()
1148
+
1149
+ def forward(
1150
+ self,
1151
+ x: Tensor,
1152
+ pos_emb: Tensor,
1153
+ key_padding_mask: Optional[Tensor] = None,
1154
+ attn_mask: Optional[Tensor] = None,
1155
+ ) -> Tensor:
1156
+ r"""
1157
+ Args:
1158
+ x: input of shape (seq_len, batch_size, embed_dim)
1159
+ pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
1160
+ key_padding_mask: a bool tensor of shape (batch_size, seq_len).
1161
+ Positions that are True in this mask will be ignored as sources in the
1162
+ attention weighting.
1163
+ attn_mask: mask of shape (seq_len, seq_len) or
1164
+ (batch_size, seq_len, seq_len), interpreted as
1165
+ ([batch_size,] tgt_seq_len, src_seq_len)
1166
+ saying which positions are allowed to attend to which other positions.
1167
+ Returns:
1168
+ a tensor of attention weights, of
1169
+ shape (num_heads, batch_size, seq_len, seq_len)
1170
+ interpreted as (num_heads, batch_size, tgt_seq_len, src_seq_len).
1171
+ """
1172
+ x = self.in_proj(x)
1173
+ query_head_dim = self.query_head_dim
1174
+ pos_head_dim = self.pos_head_dim
1175
+ num_heads = self.num_heads
1176
+
1177
+ seq_len, batch_size, _ = x.shape
1178
+
1179
+ query_dim = query_head_dim * num_heads
1180
+
1181
+ # self-attention
1182
+ q = x[..., 0:query_dim]
1183
+ k = x[..., query_dim : 2 * query_dim]
1184
+ # p is the position-encoding query
1185
+ p = x[..., 2 * query_dim :]
1186
+ assert p.shape[-1] == num_heads * pos_head_dim, (
1187
+ p.shape[-1],
1188
+ num_heads,
1189
+ pos_head_dim,
1190
+ )
1191
+
1192
+ q = self.copy_query(q) # for diagnostics only, does nothing.
1193
+ k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
1194
+ p = self.copy_pos_query(p) # for diagnostics only, does nothing.
1195
+
1196
+ q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
1197
+ p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
1198
+ k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
1199
+
1200
+ # time1 refers to target, time2 refers to source.
1201
+ q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
1202
+ p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
1203
+ k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
1204
+
1205
+ attn_scores = torch.matmul(q, k)
1206
+
1207
+ use_pos_scores = False
1208
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1209
+ # We can't put random.random() in the same line
1210
+ use_pos_scores = True
1211
+ elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
1212
+ use_pos_scores = True
1213
+
1214
+ if use_pos_scores:
1215
+ pos_emb = self.linear_pos(pos_emb)
1216
+ seq_len2 = 2 * seq_len - 1
1217
+ pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
1218
+ 2, 0, 3, 1
1219
+ )
1220
+ # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
1221
+
1222
+ # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head,
1223
+ # batch, time1, seq_len2) [where seq_len2 represents relative position.]
1224
+ pos_scores = torch.matmul(p, pos_emb)
1225
+ # the following .as_strided() expression converts the last axis of
1226
+ # pos_scores from relative to absolute position. I don't know whether I
1227
+ # might have got the time-offsets backwards or not, but let this code define
1228
+ # which way round it is supposed to be.
1229
+ if torch.jit.is_tracing():
1230
+ (num_heads, batch_size, time1, n) = pos_scores.shape
1231
+ rows = torch.arange(start=time1 - 1, end=-1, step=-1)
1232
+ cols = torch.arange(seq_len)
1233
+ rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
1234
+ indexes = rows + cols
1235
+ pos_scores = pos_scores.reshape(-1, n)
1236
+ pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
1237
+ pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
1238
+ else:
1239
+ pos_scores = pos_scores.as_strided(
1240
+ (num_heads, batch_size, seq_len, seq_len),
1241
+ (
1242
+ pos_scores.stride(0),
1243
+ pos_scores.stride(1),
1244
+ pos_scores.stride(2) - pos_scores.stride(3),
1245
+ pos_scores.stride(3),
1246
+ ),
1247
+ storage_offset=pos_scores.stride(3) * (seq_len - 1),
1248
+ )
1249
+
1250
+ attn_scores = attn_scores + pos_scores
1251
+
1252
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1253
+ pass
1254
+ elif self.training and random.random() < 0.1:
1255
+ # This is a harder way of limiting the attention scores to not be
1256
+ # too large. It incurs a penalty if any of them has an absolute
1257
+ # value greater than 50.0. this should be outside the normal range
1258
+ # of the attention scores. We use this mechanism instead of, say,
1259
+ # something added to the loss function involving the entropy,
1260
+ # because once the entropy gets very small gradients through the
1261
+ # softmax can become very small, and we'd get zero derivatives. The
1262
+ # choices of 1.0e-04 as the scale on the penalty makes this
1263
+ # mechanism vulnerable to the absolute scale of the loss function,
1264
+ # but we view this as a failsafe to avoid "implausible" parameter
1265
+ # values rather than a regularization method that should be active
1266
+ # under normal circumstances.
1267
+ attn_scores = penalize_abs_values_gt(
1268
+ attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
1269
+ )
1270
+
1271
+ assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
1272
+
1273
+ if attn_mask is not None:
1274
+ assert attn_mask.dtype == torch.bool
1275
+ # use -1000 to avoid nan's where attn_mask and key_padding_mask make
1276
+ # all scores zero. It's important that this be large enough that exp(-1000)
1277
+ # is exactly zero, for reasons related to const_attention_rate, it
1278
+ # compares the final weights with zero.
1279
+ attn_scores = attn_scores.masked_fill(attn_mask, -1000)
1280
+
1281
+ if key_padding_mask is not None:
1282
+ assert key_padding_mask.shape == (
1283
+ batch_size,
1284
+ seq_len,
1285
+ ), key_padding_mask.shape
1286
+ attn_scores = attn_scores.masked_fill(
1287
+ key_padding_mask.unsqueeze(1),
1288
+ -1000,
1289
+ )
1290
+
1291
+ # We use our own version of softmax, defined in scaling.py, which should
1292
+ # save a little of the memory used in backprop, if we are in
1293
+ # automatic mixed precision mode (amp / autocast), by only storing the
1294
+ # half-precision output for backprop purposes.
1295
+ attn_weights = softmax(attn_scores, dim=-1)
1296
+
1297
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1298
+ pass
1299
+ elif random.random() < 0.001 and not self.training:
1300
+ self._print_attn_entropy(attn_weights)
1301
+
1302
+ attn_weights = nn.functional.dropout(
1303
+ attn_weights, p=self.dropout, training=self.training
1304
+ )
1305
+
1306
+ return attn_weights
1307
+
1308
+ def _print_attn_entropy(self, attn_weights: Tensor):
1309
+ # attn_weights: (num_heads, batch_size, seq_len, seq_len)
1310
+ (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
1311
+
1312
+ with torch.no_grad():
1313
+ with torch.amp.autocast("cuda", enabled=False):
1314
+ attn_weights = attn_weights.to(torch.float32)
1315
+ attn_weights_entropy = (
1316
+ -((attn_weights + 1.0e-20).log() * attn_weights)
1317
+ .sum(dim=-1)
1318
+ .mean(dim=(1, 2))
1319
+ )
1320
+ logging.debug(
1321
+ f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}"
1322
+ )
1323
+
1324
+
1325
+ class SelfAttention(nn.Module):
1326
+ """
1327
+ The simplest possible attention module. This one works with already-computed
1328
+ attention weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
1329
+
1330
+ Args:
1331
+ embed_dim: the input and output embedding dimension
1332
+ num_heads: the number of attention heads
1333
+ value_head_dim: the value dimension per head
1334
+ """
1335
+
1336
+ def __init__(
1337
+ self,
1338
+ embed_dim: int,
1339
+ num_heads: int,
1340
+ value_head_dim: int,
1341
+ ) -> None:
1342
+ super().__init__()
1343
+ self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
1344
+
1345
+ self.out_proj = ScaledLinear(
1346
+ num_heads * value_head_dim,
1347
+ embed_dim,
1348
+ bias=True,
1349
+ initial_scale=0.05,
1350
+ )
1351
+
1352
+ self.whiten = Whiten(
1353
+ num_groups=1,
1354
+ whitening_limit=_whitening_schedule(7.5, ratio=3.0),
1355
+ prob=(0.025, 0.25),
1356
+ grad_scale=0.01,
1357
+ )
1358
+
1359
+ def forward(
1360
+ self,
1361
+ x: Tensor,
1362
+ attn_weights: Tensor,
1363
+ ) -> Tensor:
1364
+ """
1365
+ Args:
1366
+ x: input tensor, of shape (seq_len, batch_size, embed_dim)
1367
+ attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
1368
+ with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
1369
+ attn_weights.sum(dim=-1) == 1.
1370
+ Returns:
1371
+ a tensor with the same shape as x.
1372
+ """
1373
+ (seq_len, batch_size, embed_dim) = x.shape
1374
+ num_heads = attn_weights.shape[0]
1375
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
1376
+
1377
+ x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
1378
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
1379
+ # now x: (num_heads, batch_size, seq_len, value_head_dim)
1380
+ value_head_dim = x.shape[-1]
1381
+
1382
+ # todo: see whether there is benefit in overriding matmul
1383
+ x = torch.matmul(attn_weights, x)
1384
+ # v: (num_heads, batch_size, seq_len, value_head_dim)
1385
+
1386
+ x = (
1387
+ x.permute(2, 1, 0, 3)
1388
+ .contiguous()
1389
+ .view(seq_len, batch_size, num_heads * value_head_dim)
1390
+ )
1391
+
1392
+ # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
1393
+ x = self.out_proj(x)
1394
+ x = self.whiten(x)
1395
+
1396
+ return x
1397
+
1398
+
1399
+ class FeedforwardModule(nn.Module):
1400
+ """Feedforward module in TTSZipformer model."""
1401
+
1402
+ def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike):
1403
+ super(FeedforwardModule, self).__init__()
1404
+ self.in_proj = nn.Linear(embed_dim, feedforward_dim)
1405
+
1406
+ self.hidden_balancer = Balancer(
1407
+ feedforward_dim,
1408
+ channel_dim=-1,
1409
+ min_positive=0.3,
1410
+ max_positive=1.0,
1411
+ min_abs=0.75,
1412
+ max_abs=5.0,
1413
+ )
1414
+
1415
+ # shared_dim=0 means we share the dropout mask along the time axis
1416
+ self.out_proj = ActivationDropoutAndLinear(
1417
+ feedforward_dim,
1418
+ embed_dim,
1419
+ activation="SwooshL",
1420
+ dropout_p=dropout,
1421
+ dropout_shared_dim=0,
1422
+ bias=True,
1423
+ initial_scale=0.1,
1424
+ )
1425
+
1426
+ self.out_whiten = Whiten(
1427
+ num_groups=1,
1428
+ whitening_limit=_whitening_schedule(7.5),
1429
+ prob=(0.025, 0.25),
1430
+ grad_scale=0.01,
1431
+ )
1432
+
1433
+ def forward(self, x: Tensor):
1434
+ x = self.in_proj(x)
1435
+ x = self.hidden_balancer(x)
1436
+ # out_proj contains SwooshL activation, then dropout, then linear.
1437
+ x = self.out_proj(x)
1438
+ x = self.out_whiten(x)
1439
+ return x
1440
+
1441
+
1442
+ class NonlinAttention(nn.Module):
1443
+ """This is like the ConvolutionModule, but refactored so that we use multiplication
1444
+ by attention weights (borrowed from the attention module) in place of actual
1445
+ convolution. We also took out the second nonlinearity, the one after the
1446
+ attention mechanism.
1447
+
1448
+ Args:
1449
+ channels (int): The number of channels of conv layers.
1450
+ """
1451
+
1452
+ def __init__(
1453
+ self,
1454
+ channels: int,
1455
+ hidden_channels: int,
1456
+ ) -> None:
1457
+ super().__init__()
1458
+
1459
+ self.hidden_channels = hidden_channels
1460
+
1461
+ self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
1462
+
1463
+ # balancer that goes before the sigmoid. Have quite a large min_abs value, at
1464
+ # 2.0, because we noticed that well-trained instances of this module have
1465
+ # abs-value before the sigmoid starting from about 3, and poorly-trained
1466
+ # instances of the module have smaller abs values before the sigmoid.
1467
+ self.balancer = Balancer(
1468
+ hidden_channels,
1469
+ channel_dim=-1,
1470
+ min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
1471
+ max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
1472
+ min_abs=0.5,
1473
+ max_abs=5.0,
1474
+ )
1475
+ self.tanh = nn.Tanh()
1476
+
1477
+ self.identity1 = Identity() # for diagnostics.
1478
+ self.identity2 = Identity() # for diagnostics.
1479
+ self.identity3 = Identity() # for diagnostics.
1480
+
1481
+ self.out_proj = ScaledLinear(
1482
+ hidden_channels, channels, bias=True, initial_scale=0.05
1483
+ )
1484
+
1485
+ self.whiten1 = Whiten(
1486
+ num_groups=1,
1487
+ whitening_limit=_whitening_schedule(5.0),
1488
+ prob=(0.025, 0.25),
1489
+ grad_scale=0.01,
1490
+ )
1491
+
1492
+ self.whiten2 = Whiten(
1493
+ num_groups=1,
1494
+ whitening_limit=_whitening_schedule(5.0, ratio=3.0),
1495
+ prob=(0.025, 0.25),
1496
+ grad_scale=0.01,
1497
+ )
1498
+
1499
+ def forward(
1500
+ self,
1501
+ x: Tensor,
1502
+ attn_weights: Tensor,
1503
+ ) -> Tensor:
1504
+ """.
1505
+ Args:
1506
+ x: a Tensor of shape (seq_len, batch_size, num_channels)
1507
+ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
1508
+ Returns:
1509
+ a Tensor with the same shape as x
1510
+ """
1511
+ x = self.in_proj(x)
1512
+
1513
+ (seq_len, batch_size, _) = x.shape
1514
+ hidden_channels = self.hidden_channels
1515
+
1516
+ s, x, y = x.chunk(3, dim=2)
1517
+
1518
+ # s will go through tanh.
1519
+
1520
+ s = self.balancer(s)
1521
+ s = self.tanh(s)
1522
+
1523
+ s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
1524
+ x = self.whiten1(x)
1525
+ x = x * s
1526
+ x = self.identity1(x) # diagnostics only, it's the identity.
1527
+
1528
+ (seq_len, batch_size, embed_dim) = x.shape
1529
+ num_heads = attn_weights.shape[0]
1530
+ assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
1531
+
1532
+ x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
1533
+ # now x: (num_heads, batch_size, seq_len, head_dim)
1534
+ x = torch.matmul(attn_weights, x)
1535
+ # now x: (num_heads, batch_size, seq_len, head_dim)
1536
+ x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
1537
+
1538
+ y = self.identity2(y)
1539
+ x = x * y
1540
+ x = self.identity3(x)
1541
+
1542
+ x = self.out_proj(x)
1543
+ x = self.whiten2(x)
1544
+ return x
1545
+
1546
+
1547
+ class ConvolutionModule(nn.Module):
1548
+ """ConvolutionModule in Zipformer2 model.
1549
+
1550
+ Args:
1551
+ channels (int): The number of channels of conv layers.
1552
+ kernel_size (int): Kernel size of conv layers.
1553
+ bias (bool): Whether to use bias in conv layers (default=True).
1554
+
1555
+ """
1556
+
1557
+ def __init__(
1558
+ self,
1559
+ channels: int,
1560
+ kernel_size: int,
1561
+ ) -> None:
1562
+ """Construct a ConvolutionModule object."""
1563
+ super(ConvolutionModule, self).__init__()
1564
+ # kernel_size should be an odd number for 'SAME' padding
1565
+ assert (kernel_size - 1) % 2 == 0
1566
+
1567
+ bottleneck_dim = channels
1568
+
1569
+ self.in_proj = nn.Linear(
1570
+ channels,
1571
+ 2 * bottleneck_dim,
1572
+ )
1573
+ # the gradients on in_proj are a little noisy, likely to do with the
1574
+ # sigmoid in glu.
1575
+
1576
+ # after in_proj we put x through a gated linear unit (nn.functional.glu). For
1577
+ # most layers the normal rms value of channels of x seems to be in the range 1
1578
+ # to 4, but sometimes, for some reason, for layer 0 the rms ends up being very
1579
+ # large, between 50 and 100 for different channels. This will cause very peaky
1580
+ # and sparse derivatives for the sigmoid gating function, which will tend to
1581
+ # make the loss function not learn effectively. (for most layers the average
1582
+ # absolute values are in the range 0.5..9.0, and the average p(x>0), i.e.
1583
+ # positive proportion, at the output of pointwise_conv1.output is around 0.35 to
1584
+ # 0.45 for different layers, which likely breaks down as 0.5 for the "linear"
1585
+ # half and 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that
1586
+ # if we constrain the rms values to a reasonable range via a constraint of
1587
+ # max_abs=10.0, it will be in a better position to start learning something,
1588
+ # i.e. to latch onto the correct range.
1589
+ self.balancer1 = Balancer(
1590
+ bottleneck_dim,
1591
+ channel_dim=-1,
1592
+ min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
1593
+ max_positive=1.0,
1594
+ min_abs=1.5,
1595
+ max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
1596
+ )
1597
+
1598
+ self.activation1 = Identity() # for diagnostics
1599
+
1600
+ self.sigmoid = nn.Sigmoid()
1601
+
1602
+ self.activation2 = Identity() # for diagnostics
1603
+
1604
+ assert kernel_size % 2 == 1
1605
+
1606
+ self.depthwise_conv = nn.Conv1d(
1607
+ in_channels=bottleneck_dim,
1608
+ out_channels=bottleneck_dim,
1609
+ groups=bottleneck_dim,
1610
+ kernel_size=kernel_size,
1611
+ padding=kernel_size // 2,
1612
+ )
1613
+
1614
+ self.balancer2 = Balancer(
1615
+ bottleneck_dim,
1616
+ channel_dim=1,
1617
+ min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
1618
+ max_positive=1.0,
1619
+ min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
1620
+ max_abs=10.0,
1621
+ )
1622
+
1623
+ self.whiten = Whiten(
1624
+ num_groups=1,
1625
+ whitening_limit=_whitening_schedule(7.5),
1626
+ prob=(0.025, 0.25),
1627
+ grad_scale=0.01,
1628
+ )
1629
+
1630
+ self.out_proj = ActivationDropoutAndLinear(
1631
+ bottleneck_dim,
1632
+ channels,
1633
+ activation="SwooshR",
1634
+ dropout_p=0.0,
1635
+ initial_scale=0.05,
1636
+ )
1637
+
1638
+ def forward(
1639
+ self,
1640
+ x: Tensor,
1641
+ src_key_padding_mask: Optional[Tensor] = None,
1642
+ ) -> Tensor:
1643
+ """Compute convolution module.
1644
+
1645
+ Args:
1646
+ x: Input tensor (#time, batch, channels).
1647
+ src_key_padding_mask: the mask for the src keys per batch (optional):
1648
+ (batch, #time), contains True in masked positions.
1649
+
1650
+ Returns:
1651
+ Tensor: Output tensor (#time, batch, channels).
1652
+
1653
+ """
1654
+
1655
+ x = self.in_proj(x) # (time, batch, 2*channels)
1656
+
1657
+ x, s = x.chunk(2, dim=2)
1658
+ s = self.balancer1(s)
1659
+ s = self.sigmoid(s)
1660
+ x = self.activation1(x) # identity.
1661
+ x = x * s
1662
+ x = self.activation2(x) # identity
1663
+
1664
+ # (time, batch, channels)
1665
+
1666
+ # exchange the temporal dimension and the feature dimension
1667
+ x = x.permute(1, 2, 0) # (#batch, channels, time).
1668
+
1669
+ if src_key_padding_mask is not None:
1670
+ x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
1671
+
1672
+ x = self.depthwise_conv(x)
1673
+
1674
+ x = self.balancer2(x)
1675
+ x = x.permute(2, 0, 1) # (time, batch, channels)
1676
+
1677
+ x = self.whiten(x) # (time, batch, channels)
1678
+ x = self.out_proj(x) # (time, batch, channels)
1679
+
1680
+ return x
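A quick shape check for the convolution module above, assuming the class as defined here; the depthwise convolution uses 'SAME' padding, so the time axis is preserved (hypothetical sizes):

    import torch

    conv = ConvolutionModule(channels=192, kernel_size=31)
    x = torch.randn(50, 2, 192)                   # (#time, batch, channels)
    mask = torch.zeros(2, 50, dtype=torch.bool)   # no padded positions
    y = conv(x, src_key_padding_mask=mask)
    assert y.shape == x.shape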
zipvoice/models/modules/zipformer_two_stream.py ADDED
@@ -0,0 +1,264 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import math
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import torch
22
+ from torch import Tensor, nn
23
+
24
+ from zipvoice.models.modules.scaling import FloatLike, ScheduledFloat, SwooshR
25
+ from zipvoice.models.modules.zipformer import (
26
+ DownsampledZipformer2Encoder,
27
+ TTSZipformer,
28
+ Zipformer2Encoder,
29
+ Zipformer2EncoderLayer,
30
+ )
31
+
32
+
33
+ def timestep_embedding(timesteps, dim, max_period=10000):
34
+ """Create sinusoidal timestep embeddings.
35
+
36
+ :param timesteps: shape of (N) or (N, T)
37
+ :param dim: the dimension of the output.
38
+ :param max_period: controls the minimum frequency of the embeddings.
39
+ :return: an Tensor of positional embeddings. shape of (N, dim) or (T, N, dim)
40
+ """
41
+ half = dim // 2
42
+ freqs = torch.exp(
43
+ -math.log(max_period)
44
+ * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
45
+ / half
46
+ )
47
+
48
+ if timesteps.dim() == 2:
49
+ timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N)
50
+
51
+ args = timesteps[..., None].float() * freqs[None]
52
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
53
+ if dim % 2:
54
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
55
+ return embedding
56
+
57
+
58
+ class TTSZipformerTwoStream(TTSZipformer):
59
+ """
60
+ Args:
61
+
62
+ Note: all "int or Tuple[int]" arguments below will be treated as lists of the same
63
+ length as downsampling_factor if they are single ints or one-element tuples.
64
+ The length of downsampling_factor defines the number of stacks.
65
+
66
+ downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
67
+ Note: this is in addition to the downsampling factor of 2 that is applied in
68
+ the frontend (self.encoder_embed).
69
+ encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks,
70
+ one per encoder stack.
71
+ num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
72
+ query_head_dim (int or Tuple[int]): dimension of query and key per attention
73
+ head: per stack, if a tuple..
74
+ pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection
75
+ per attention head
76
+ value_head_dim (int or Tuple[int]): dimension of value in each attention head
77
+ num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
78
+ Must be at least 4.
79
+ feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
80
+ cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
81
+
82
+ pos_dim (int): the dimension of each positional-encoding vector prior to
83
+ projection, e.g. 128.
84
+
85
+ dropout (float): dropout rate
86
+ warmup_batches (float): number of batches to warm up over; this controls
87
+ dropout of encoder layers.
88
+ use_time_embed: (bool): if True, do not take time embedding as additional input.
89
+ time_embed_dim: (int): the dimension of the time embedding.
90
+ """
91
+
92
+ def __init__(
93
+ self,
94
+ in_dim: Tuple[int],
95
+ out_dim: Tuple[int],
96
+ downsampling_factor: Tuple[int] = (2, 4),
97
+ num_encoder_layers: Union[int, Tuple[int]] = 4,
98
+ cnn_module_kernel: Union[int, Tuple[int]] = 31,
99
+ encoder_dim: int = 384,
100
+ query_head_dim: int = 24,
101
+ pos_head_dim: int = 4,
102
+ value_head_dim: int = 12,
103
+ num_heads: int = 8,
104
+ feedforward_dim: int = 1536,
105
+ pos_dim: int = 192,
106
+ dropout: FloatLike = None, # see code below for default
107
+ warmup_batches: float = 4000.0,
108
+ use_time_embed: bool = True,
109
+ time_embed_dim: int = 192,
110
+ use_conv: bool = True,
111
+ ) -> None:
112
+ nn.Module.__init__(self)
113
+
114
+ if dropout is None:
115
+ dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
116
+ if isinstance(downsampling_factor, int):
117
+ downsampling_factor = (downsampling_factor,)
118
+
119
+ def _to_tuple(x):
120
+ """Converts a single int or a 1-tuple of an int to a tuple with the same
121
+ length as downsampling_factor"""
122
+ if isinstance(x, int):
123
+ x = (x,)
124
+ if len(x) == 1:
125
+ x = x * len(downsampling_factor)
126
+ else:
127
+ assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
128
+ return x
129
+
130
+ def _assert_downsampling_factor(factors):
131
+ """assert downsampling_factor follows u-net style"""
132
+ assert factors[0] == 1 and factors[-1] == 1
133
+
134
+ for i in range(1, len(factors) // 2 + 1):
135
+ assert factors[i] == factors[i - 1] * 2
136
+
137
+ for i in range(len(factors) // 2 + 1, len(factors)):
138
+ assert factors[i] * 2 == factors[i - 1]
139
+
140
+ _assert_downsampling_factor(downsampling_factor)
141
+ self.downsampling_factor = downsampling_factor # tuple
142
+ num_encoder_layers = _to_tuple(num_encoder_layers)
143
+ self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
144
+ self.encoder_dim = encoder_dim
145
+ self.num_encoder_layers = num_encoder_layers
146
+ self.query_head_dim = query_head_dim
147
+ self.value_head_dim = value_head_dim
148
+ self.num_heads = num_heads
149
+
150
+ self.use_time_embed = use_time_embed
151
+
152
+ self.time_embed_dim = time_embed_dim
153
+ if self.use_time_embed:
154
+ assert time_embed_dim != -1
155
+ else:
156
+ time_embed_dim = -1
157
+
158
+ assert len(in_dim) == len(out_dim) == 2
159
+
160
+ self.in_dim = in_dim
161
+ self.in_proj = nn.ModuleList(
162
+ [nn.Linear(in_dim[0], encoder_dim), nn.Linear(in_dim[1], encoder_dim)]
163
+ )
164
+ self.out_dim = out_dim
165
+ self.out_proj = nn.ModuleList(
166
+ [nn.Linear(encoder_dim, out_dim[0]), nn.Linear(encoder_dim, out_dim[1])]
167
+ )
168
+
169
+ # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
170
+ encoders = []
171
+
172
+ num_encoders = len(downsampling_factor)
173
+ for i in range(num_encoders):
174
+ encoder_layer = Zipformer2EncoderLayer(
175
+ embed_dim=encoder_dim,
176
+ pos_dim=pos_dim,
177
+ num_heads=num_heads,
178
+ query_head_dim=query_head_dim,
179
+ pos_head_dim=pos_head_dim,
180
+ value_head_dim=value_head_dim,
181
+ feedforward_dim=feedforward_dim,
182
+ use_conv=use_conv,
183
+ cnn_module_kernel=cnn_module_kernel[i],
184
+ dropout=dropout,
185
+ )
186
+
187
+ # For the segment of the warmup period, we let the Conv2dSubsampling
188
+ # layer learn something. Then we start to warm up the other encoders.
189
+ encoder = Zipformer2Encoder(
190
+ encoder_layer,
191
+ num_encoder_layers[i],
192
+ embed_dim=encoder_dim,
193
+ time_embed_dim=time_embed_dim,
194
+ pos_dim=pos_dim,
195
+ warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
196
+ warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
197
+ final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
198
+ )
199
+
200
+ if downsampling_factor[i] != 1:
201
+ encoder = DownsampledZipformer2Encoder(
202
+ encoder,
203
+ dim=encoder_dim,
204
+ downsample=downsampling_factor[i],
205
+ )
206
+
207
+ encoders.append(encoder)
208
+
209
+ self.encoders = nn.ModuleList(encoders)
210
+ if self.use_time_embed:
211
+ self.time_embed = nn.Sequential(
212
+ nn.Linear(time_embed_dim, time_embed_dim * 2),
213
+ SwooshR(),
214
+ nn.Linear(time_embed_dim * 2, time_embed_dim),
215
+ )
216
+ else:
217
+ self.time_embed = None
218
+
219
+ def forward(
220
+ self,
221
+ x: Tensor,
222
+ t: Optional[Tensor] = None,
223
+ padding_mask: Optional[Tensor] = None,
224
+ ) -> Tuple[Tensor, Tensor]:
225
+ """
226
+ Args:
227
+ x:
228
+ The input tensor. Its shape is (batch_size, seq_len, feature_dim).
229
+ t:
230
+ A t tensor of shape (batch_size,) or (batch_size, seq_len)
231
+ padding_mask:
232
+ The mask for padding, of shape (batch_size, seq_len); True means
233
+ masked position. May be None.
234
+ Returns:
235
+ Return the output embeddings. its shape is
236
+ (batch_size, output_seq_len, encoder_dim)
237
+ """
238
+ assert x.size(2) in self.in_dim, f"{x.size(2)} in {self.in_dim}"
239
+ if x.size(2) == self.in_dim[0]:
240
+ index = 0
241
+ else:
242
+ index = 1
243
+ x = x.permute(1, 0, 2)
244
+ x = self.in_proj[index](x)
245
+
246
+ if t is not None:
247
+ assert t.dim() == 1 or t.dim() == 2, t.shape
248
+ time_emb = timestep_embedding(t, self.time_embed_dim)
249
+ time_emb = self.time_embed(time_emb)
250
+ else:
251
+ time_emb = None
252
+
253
+ attn_mask = None
254
+
255
+ for i, module in enumerate(self.encoders):
256
+ x = module(
257
+ x,
258
+ time_emb=time_emb,
259
+ src_key_padding_mask=padding_mask,
260
+ attn_mask=attn_mask,
261
+ )
262
+ x = self.out_proj[index](x)
263
+ x = x.permute(1, 0, 2)
264
+ return x
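A minimal usage sketch for the two-stream model above, assuming this module's imports resolve; all sizes are hypothetical. The input feature dimension selects which in_proj/out_proj stream is used, and downsampling_factor must follow the u-net pattern asserted in __init__:

    import torch

    model = TTSZipformerTwoStream(
        in_dim=(100, 80),
        out_dim=(100, 80),
        downsampling_factor=(1, 2, 1),
    )
    model.eval()
    x = torch.randn(2, 120, 100)     # (batch, seq_len, feature_dim) -> stream 0
    t = torch.rand(2)                # one timestep per batch item
    y = model(x, t=t)
    assert y.shape == (2, 120, 100)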
zipvoice/models/zipvoice.py ADDED
@@ -0,0 +1,534 @@
1
+ # Copyright 2024 Xiaomi Corp. (authors: Wei Kang
2
+ # Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from typing import List, Optional
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.nn.parallel import DistributedDataParallel as DDP
23
+
24
+ from zipvoice.models.modules.solver import EulerSolver
25
+ from zipvoice.models.modules.zipformer import TTSZipformer
26
+ from zipvoice.utils.common import (
27
+ condition_time_mask,
28
+ get_tokens_index,
29
+ make_pad_mask,
30
+ pad_labels,
31
+ prepare_avg_tokens_durations,
32
+ )
33
+
34
+
35
+ class ZipVoice(nn.Module):
36
+ """The ZipVoice model."""
37
+
38
+ def __init__(
39
+ self,
40
+ fm_decoder_downsampling_factor: List[int] = [1, 2, 4, 2, 1],
41
+ fm_decoder_num_layers: List[int] = [2, 2, 4, 4, 4],
42
+ fm_decoder_cnn_module_kernel: List[int] = [31, 15, 7, 15, 31],
43
+ fm_decoder_feedforward_dim: int = 1536,
44
+ fm_decoder_num_heads: int = 4,
45
+ fm_decoder_dim: int = 512,
46
+ text_encoder_num_layers: int = 4,
47
+ text_encoder_feedforward_dim: int = 512,
48
+ text_encoder_cnn_module_kernel: int = 9,
49
+ text_encoder_num_heads: int = 4,
50
+ text_encoder_dim: int = 192,
51
+ time_embed_dim: int = 192,
52
+ text_embed_dim: int = 192,
53
+ query_head_dim: int = 32,
54
+ value_head_dim: int = 12,
55
+ pos_head_dim: int = 4,
56
+ pos_dim: int = 48,
57
+ feat_dim: int = 100,
58
+ vocab_size: int = 26,
59
+ pad_id: int = 0,
60
+ ):
61
+ """
62
+ Initialize the model with specified configuration parameters.
63
+
64
+ Args:
65
+ fm_decoder_downsampling_factor: List of downsampling factors for each layer
66
+ in the flow-matching decoder.
67
+ fm_decoder_num_layers: List of the number of layers for each block in the
68
+ flow-matching decoder.
69
+ fm_decoder_cnn_module_kernel: List of kernel sizes for CNN modules in the
70
+ flow-matching decoder.
71
+ fm_decoder_feedforward_dim: Dimension of the feedforward network in the
72
+ flow-matching decoder.
73
+ fm_decoder_num_heads: Number of attention heads in the flow-matching
74
+ decoder.
75
+ fm_decoder_dim: Hidden dimension of the flow-matching decoder.
76
+ text_encoder_num_layers: Number of layers in the text encoder.
77
+ text_encoder_feedforward_dim: Dimension of the feedforward network in the
78
+ text encoder.
79
+ text_encoder_cnn_module_kernel: Kernel size for the CNN module in the
80
+ text encoder.
81
+ text_encoder_num_heads: Number of attention heads in the text encoder.
82
+ text_encoder_dim: Hidden dimension of the text encoder.
83
+ time_embed_dim: Dimension of the time embedding.
84
+ text_embed_dim: Dimension of the text embedding.
85
+ query_head_dim: Dimension of the query attention head.
86
+ value_head_dim: Dimension of the value attention head.
87
+ pos_head_dim: Dimension of the position attention head.
88
+ pos_dim: Dimension of the positional encoding.
89
+ feat_dim: Dimension of the acoustic features.
90
+ vocab_size: Size of the vocabulary.
91
+ pad_id: ID used for padding tokens.
92
+ """
93
+ super().__init__()
94
+
95
+ self.fm_decoder = TTSZipformer(
96
+ in_dim=feat_dim * 3,
97
+ out_dim=feat_dim,
98
+ downsampling_factor=fm_decoder_downsampling_factor,
99
+ num_encoder_layers=fm_decoder_num_layers,
100
+ cnn_module_kernel=fm_decoder_cnn_module_kernel,
101
+ encoder_dim=fm_decoder_dim,
102
+ feedforward_dim=fm_decoder_feedforward_dim,
103
+ num_heads=fm_decoder_num_heads,
104
+ query_head_dim=query_head_dim,
105
+ pos_head_dim=pos_head_dim,
106
+ value_head_dim=value_head_dim,
107
+ pos_dim=pos_dim,
108
+ use_time_embed=True,
109
+ time_embed_dim=time_embed_dim,
110
+ )
111
+
112
+ self.text_encoder = TTSZipformer(
113
+ in_dim=text_embed_dim,
114
+ out_dim=feat_dim,
115
+ downsampling_factor=1,
116
+ num_encoder_layers=text_encoder_num_layers,
117
+ cnn_module_kernel=text_encoder_cnn_module_kernel,
118
+ encoder_dim=text_encoder_dim,
119
+ feedforward_dim=text_encoder_feedforward_dim,
120
+ num_heads=text_encoder_num_heads,
121
+ query_head_dim=query_head_dim,
122
+ pos_head_dim=pos_head_dim,
123
+ value_head_dim=value_head_dim,
124
+ pos_dim=pos_dim,
125
+ use_time_embed=False,
126
+ )
127
+
128
+ self.feat_dim = feat_dim
129
+ self.text_embed_dim = text_embed_dim
130
+ self.pad_id = pad_id
131
+
132
+ self.embed = nn.Embedding(vocab_size, text_embed_dim)
133
+ self.solver = EulerSolver(self, func_name="forward_fm_decoder")
134
+
135
+ def forward_fm_decoder(
136
+ self,
137
+ t: torch.Tensor,
138
+ xt: torch.Tensor,
139
+ text_condition: torch.Tensor,
140
+ speech_condition: torch.Tensor,
141
+ padding_mask: Optional[torch.Tensor] = None,
142
+ guidance_scale: Optional[torch.Tensor] = None,
143
+ ) -> torch.Tensor:
144
+ """Compute velocity.
145
+ Args:
146
+ t: A tensor of shape (N, 1, 1) or a tensor of a float,
147
+ in the range of (0, 1).
148
+ xt: the input of the current timestep, including condition
149
+ embeddings and noisy acoustic features.
150
+ text_condition: the text condition embeddings, with the
151
+ shape (batch, seq_len, emb_dim).
152
+ speech_condition: the speech condition embeddings, with the
153
+ shape (batch, seq_len, emb_dim).
154
+ padding_mask: The mask for padding, True means masked
155
+ position, with the shape (N, T).
156
+ guidance_scale: The guidance scale in classifier-free guidance,
157
+ which is a tensor of shape (N, 1, 1) or a tensor of a float.
158
+
159
+ Returns:
160
+ predicted velocity, with the shape (batch, seq_len, emb_dim).
161
+ """
162
+
163
+ xt = torch.cat([xt, text_condition, speech_condition], dim=2)
164
+
165
+ assert t.dim() in (0, 3)
166
+ # Handle t with the shape (N, 1, 1):
167
+ # squeeze the last dimension if its size is 1.
168
+ while t.dim() > 1 and t.size(-1) == 1:
169
+ t = t.squeeze(-1)
170
+ # Handle t with a single value: expand it to the batch size.
171
+ if t.dim() == 0:
172
+ t = t.repeat(xt.shape[0])
173
+
174
+ if guidance_scale is not None:
175
+ while guidance_scale.dim() > 1 and guidance_scale.size(-1) == 1:
176
+ guidance_scale = guidance_scale.squeeze(-1)
177
+ if guidance_scale.dim() == 0:
178
+ guidance_scale = guidance_scale.repeat(xt.shape[0])
179
+
180
+ vt = self.fm_decoder(
181
+ x=xt, t=t, padding_mask=padding_mask, guidance_scale=guidance_scale
182
+ )
183
+ else:
184
+ vt = self.fm_decoder(x=xt, t=t, padding_mask=padding_mask)
185
+ return vt
186
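forward_fm_decoder above accepts t either as a per-example tensor of shape (N, 1, 1) or as a single scalar; both are normalized to shape (N,) before reaching the decoder. A standalone sketch of that normalization (plain tensor manipulation, not a call into the model):

    import torch

    def normalize_t(t: torch.Tensor, batch_size: int) -> torch.Tensor:
        # Mirrors the squeeze/expand logic in forward_fm_decoder.
        while t.dim() > 1 and t.size(-1) == 1:
            t = t.squeeze(-1)
        if t.dim() == 0:
            t = t.repeat(batch_size)
        return t

    assert normalize_t(torch.full((4, 1, 1), 0.3), 4).shape == (4,)
    assert normalize_t(torch.tensor(0.3), 4).shape == (4,)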
+
187
+ def forward_text_embed(
188
+ self,
189
+ tokens: List[List[int]],
190
+ ):
191
+ """
192
+ Get the text embeddings.
193
+ Args:
194
+ tokens: a list of list of token ids.
195
+ Returns:
196
+ embed: the text embeddings, shape (batch, seq_len, emb_dim).
197
+ tokens_lens: the length of each token sequence, shape (batch,).
198
+ """
199
+ device = (
200
+ self.device if isinstance(self, DDP) else next(self.parameters()).device
201
+ )
202
+ tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device) # (B, S)
203
+ embed = self.embed(tokens_padded) # (B, S, C)
204
+ tokens_lens = torch.tensor(
205
+ [len(token) for token in tokens], dtype=torch.int64, device=device
206
+ )
207
+ tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1]) # (B, S)
208
+
209
+ embed = self.text_encoder(
210
+ x=embed, t=None, padding_mask=tokens_padding_mask
211
+ ) # (B, S, C)
212
+ return embed, tokens_lens
213
+
214
+ def forward_text_condition(
215
+ self,
216
+ embed: torch.Tensor,
217
+ tokens_lens: torch.Tensor,
218
+ features_lens: torch.Tensor,
219
+ ):
220
+ """
221
+ Get the text condition with the same length of the acoustic feature.
222
+ Args:
223
+ embed: the text embeddings, shape (batch, token_seq_len, emb_dim).
224
+ tokens_lens: the length of each token sequence, shape (batch,).
225
+ features_lens: the length of each acoustic feature sequence,
226
+ shape (batch,).
227
+ Returns:
228
+ text_condition: the text condition, shape
229
+ (batch, feature_seq_len, emb_dim).
230
+ padding_mask: the padding mask of text condition, shape
231
+ (batch, feature_seq_len).
232
+ """
233
+
234
+ num_frames = int(features_lens.max())
235
+
236
+ padding_mask = make_pad_mask(features_lens, max_len=num_frames) # (B, T)
237
+
238
+ tokens_durations = prepare_avg_tokens_durations(features_lens, tokens_lens)
239
+
240
+ tokens_index = get_tokens_index(tokens_durations, num_frames).to(
241
+ embed.device
242
+ ) # (B, T)
243
+
244
+ text_condition = torch.gather(
245
+ embed,
246
+ dim=1,
247
+ index=tokens_index.unsqueeze(-1).expand(
248
+ embed.size(0), num_frames, embed.size(-1)
249
+ ),
250
+ ) # (B, T, F)
251
+ return text_condition, padding_mask
252
+
253
+ def forward_text_train(
254
+ self,
255
+ tokens: List[List[int]],
256
+ features_lens: torch.Tensor,
257
+ ):
258
+ """
259
+ Process text for training, given text tokens and real feature lengths.
260
+ """
261
+ embed, tokens_lens = self.forward_text_embed(tokens)
262
+ text_condition, padding_mask = self.forward_text_condition(
263
+ embed, tokens_lens, features_lens
264
+ )
265
+ return (
266
+ text_condition,
267
+ padding_mask,
268
+ )
269
+
270
+ def forward_text_inference_gt_duration(
271
+ self,
272
+ tokens: List[List[int]],
273
+ features_lens: torch.Tensor,
274
+ prompt_tokens: List[List[int]],
275
+ prompt_features_lens: torch.Tensor,
276
+ ):
277
+ """
278
+ Process text for inference, given text tokens, real feature lengths and prompts.
279
+ """
280
+ tokens = [
281
+ prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
282
+ ]
283
+ features_lens = prompt_features_lens + features_lens
284
+ embed, tokens_lens = self.forward_text_embed(tokens)
285
+ text_condition, padding_mask = self.forward_text_condition(
286
+ embed, tokens_lens, features_lens
287
+ )
288
+ return text_condition, padding_mask
289
+
290
+ def forward_text_inference_ratio_duration(
291
+ self,
292
+ tokens: List[List[int]],
293
+ prompt_tokens: List[List[int]],
294
+ prompt_features_lens: torch.Tensor,
295
+ speed: float,
296
+ ):
297
+ """
298
+ Process text for inference, given text tokens and prompts;
299
+ feature lengths are predicted with the ratio of token numbers.
300
+ """
301
+ device = (
302
+ self.device if isinstance(self, DDP) else next(self.parameters()).device
303
+ )
304
+
305
+ cat_tokens = [
306
+ prompt_token + token for prompt_token, token in zip(prompt_tokens, tokens)
307
+ ]
308
+
309
+ prompt_tokens_lens = torch.tensor(
310
+ [len(token) for token in prompt_tokens],
311
+ dtype=torch.int64,
312
+ device=device,
313
+ )
314
+
315
+ tokens_lens = torch.tensor(
316
+ [len(token) for token in tokens],
317
+ dtype=torch.int64,
318
+ device=device,
319
+ )
320
+
321
+ cat_embed, cat_tokens_lens = self.forward_text_embed(cat_tokens)
322
+
323
+ features_lens = prompt_features_lens + torch.ceil(
324
+ (prompt_features_lens / prompt_tokens_lens * tokens_lens / speed)
325
+ ).to(dtype=torch.int64)
326
+
327
+ text_condition, padding_mask = self.forward_text_condition(
328
+ cat_embed, cat_tokens_lens, features_lens
329
+ )
330
+ return text_condition, padding_mask
331
+
332
+ def forward(
333
+ self,
334
+ tokens: List[List[int]],
335
+ features: torch.Tensor,
336
+ features_lens: torch.Tensor,
337
+ noise: torch.Tensor,
338
+ t: torch.Tensor,
339
+ condition_drop_ratio: float = 0.0,
340
+ ) -> torch.Tensor:
341
+ """Forward pass of the model for training.
342
+ Args:
343
+ tokens: a list of list of token ids.
344
+ features: the acoustic features, with the shape (batch, seq_len, feat_dim).
345
+ features_lens: the length of each acoustic feature sequence, shape (batch,).
346
+ noise: the initial noise, with the shape (batch, seq_len, feat_dim).
347
+ t: the time step, with the shape (batch, 1, 1).
348
+ condition_drop_ratio: the ratio of dropped text condition.
349
+ Returns:
350
+ fm_loss: the flow-matching loss.
351
+ """
352
+
353
+ (text_condition, padding_mask,) = self.forward_text_train(
354
+ tokens=tokens,
355
+ features_lens=features_lens,
356
+ )
357
+
358
+ speech_condition_mask = condition_time_mask(
359
+ features_lens=features_lens,
360
+ mask_percent=(0.7, 1.0),
361
+ max_len=features.size(1),
362
+ )
363
+ speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
364
+
365
+ if condition_drop_ratio > 0.0:
366
+ drop_mask = (
367
+ torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
368
+ > condition_drop_ratio
369
+ )
370
+ text_condition = text_condition * drop_mask
371
+
372
+ xt = features * t + noise * (1 - t)
373
+ ut = features - noise # (B, T, F)
374
+
375
+ vt = self.forward_fm_decoder(
376
+ t=t,
377
+ xt=xt,
378
+ text_condition=text_condition,
379
+ speech_condition=speech_condition,
380
+ padding_mask=padding_mask,
381
+ )
382
+
383
+ loss_mask = speech_condition_mask & (~padding_mask)
384
+ fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
385
+
386
+ return fm_loss
387
+
388
+ def sample(
389
+ self,
390
+ tokens: List[List[int]],
391
+ prompt_tokens: List[List[int]],
392
+ prompt_features: torch.Tensor,
393
+ prompt_features_lens: torch.Tensor,
394
+ features_lens: Optional[torch.Tensor] = None,
395
+ speed: float = 1.0,
396
+ t_shift: float = 1.0,
397
+ duration: str = "predict",
398
+ num_step: int = 5,
399
+ guidance_scale: float = 0.5,
400
+ ) -> torch.Tensor:
401
+ """
402
+ Generate acoustic features, given text tokens, prompt features
403
+ and prompt transcription's text tokens.
404
+ Args:
405
+ tokens: a list of list of text tokens.
406
+ prompt_tokens: a list of list of prompt tokens.
407
+ prompt_features: the prompt feature with the shape
408
+ (batch_size, seq_len, feat_dim).
409
+ prompt_features_lens: the length of each prompt feature,
410
+ with the shape (batch_size,).
411
+ features_lens: the length of the predicted feature, with the
412
+ shape (batch_size,). It is used only when duration is "real".
413
+ duration: "real" or "predict". If "real", the predicted
414
+ feature length is given by features_lens.
415
+ num_step: the number of steps to use in the ODE solver.
416
+ guidance_scale: the guidance scale for classifier-free guidance.
417
+ """
418
+
419
+ assert duration in ["real", "predict"]
420
+
421
+ if duration == "predict":
422
+ (
423
+ text_condition,
424
+ padding_mask,
425
+ ) = self.forward_text_inference_ratio_duration(
426
+ tokens=tokens,
427
+ prompt_tokens=prompt_tokens,
428
+ prompt_features_lens=prompt_features_lens,
429
+ speed=speed,
430
+ )
431
+ else:
432
+ assert features_lens is not None
433
+ text_condition, padding_mask = self.forward_text_inference_gt_duration(
434
+ tokens=tokens,
435
+ features_lens=features_lens,
436
+ prompt_tokens=prompt_tokens,
437
+ prompt_features_lens=prompt_features_lens,
438
+ )
439
+ batch_size, num_frames, _ = text_condition.shape
440
+
441
+ speech_condition = torch.nn.functional.pad(
442
+ prompt_features, (0, 0, 0, num_frames - prompt_features.size(1))
443
+ ) # (B, T, F)
444
+
445
+ # False means speech condition positions.
446
+ speech_condition_mask = make_pad_mask(prompt_features_lens, num_frames)
447
+ speech_condition = torch.where(
448
+ speech_condition_mask.unsqueeze(-1),
449
+ torch.zeros_like(speech_condition),
450
+ speech_condition,
451
+ )
452
+
453
+ x0 = torch.randn(
454
+ batch_size,
455
+ num_frames,
456
+ prompt_features.size(-1),
457
+ device=text_condition.device,
458
+ )
459
+
460
+ x1 = self.solver.sample(
461
+ x=x0,
462
+ text_condition=text_condition,
463
+ speech_condition=speech_condition,
464
+ padding_mask=padding_mask,
465
+ num_step=num_step,
466
+ guidance_scale=guidance_scale,
467
+ t_shift=t_shift,
468
+ )
469
+ x1_wo_prompt_lens = (~padding_mask).sum(-1) - prompt_features_lens
470
+ x1_prompt = torch.zeros(
471
+ x1.size(0), prompt_features_lens.max(), x1.size(2), device=x1.device
472
+ )
473
+ x1_wo_prompt = torch.zeros(
474
+ x1.size(0), x1_wo_prompt_lens.max(), x1.size(2), device=x1.device
475
+ )
476
+ for i in range(x1.size(0)):
477
+ x1_wo_prompt[i, : x1_wo_prompt_lens[i], :] = x1[
478
+ i,
479
+ prompt_features_lens[i] : prompt_features_lens[i]
480
+ + x1_wo_prompt_lens[i],
481
+ ]
482
+ x1_prompt[i, : prompt_features_lens[i], :] = x1[
483
+ i, : prompt_features_lens[i]
484
+ ]
485
+
486
+ return x1_wo_prompt, x1_wo_prompt_lens, x1_prompt, prompt_features_lens
487
+
488
+ def sample_intermediate(
489
+ self,
490
+ tokens: List[List[int]],
491
+ features: torch.Tensor,
492
+ features_lens: torch.Tensor,
493
+ noise: torch.Tensor,
494
+ speech_condition_mask: torch.Tensor,
495
+ t_start: float,
496
+ t_end: float,
497
+ num_step: int = 1,
498
+ guidance_scale: torch.Tensor = None,
499
+ ) -> torch.Tensor:
500
+ """
501
+ Generate acoustic features in intermediate timesteps.
502
+ Args:
503
+ tokens: List of list of token ids.
504
+ features: The acoustic features, with the shape (batch, seq_len, feat_dim).
505
+ features_lens: The length of each acoustic feature sequence,
506
+ with the shape (batch,).
507
+ noise: The initial noise, with the shape (batch, seq_len, feat_dim).
508
+ speech_condition_mask: The mask for speech condition, True means
509
+ non-condition positions, with the shape (batch, seq_len).
510
+ t_start: The start timestep.
511
+ t_end: The end timestep.
512
+ num_step: The number of steps for sampling.
513
+ guidance_scale: The scale for classifier-free guidance inference,
514
+ with the shape (batch, 1, 1).
515
+ """
516
+ (text_condition, padding_mask,) = self.forward_text_train(
517
+ tokens=tokens,
518
+ features_lens=features_lens,
519
+ )
520
+
521
+ speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
522
+
523
+ x_t_end = self.solver.sample(
524
+ x=noise,
525
+ text_condition=text_condition,
526
+ speech_condition=speech_condition,
527
+ padding_mask=padding_mask,
528
+ num_step=num_step,
529
+ guidance_scale=guidance_scale,
530
+ t_start=t_start,
531
+ t_end=t_end,
532
+ )
533
+ x_t_end_lens = (~padding_mask).sum(-1)
534
+ return x_t_end, x_t_end_lens
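For intuition, the flow-matching objective implemented in ZipVoice.forward() above reduces to a few tensor operations: interpolate between noise and data, take their difference as the target velocity, and regress the predicted velocity onto it. Below is a minimal standalone sketch of that objective; the nn.Linear is a placeholder for the real Zipformer decoder and is not part of this repository.

import torch
import torch.nn as nn

batch, seq_len, feat_dim = 2, 50, 100
features = torch.randn(batch, seq_len, feat_dim)  # x1: target acoustic features
noise = torch.randn_like(features)                # x0: Gaussian noise
t = torch.rand(batch, 1, 1)                       # timestep in (0, 1)

xt = features * t + noise * (1 - t)               # point on the straight path from noise to data
ut = features - noise                             # ground-truth velocity, constant along that path

toy_decoder = nn.Linear(feat_dim, feat_dim)       # stand-in for the flow-matching decoder
vt = toy_decoder(xt)                              # predicted velocity

fm_loss = torch.mean((vt - ut) ** 2)              # restricted to loss_mask positions in the real code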
zipvoice/models/zipvoice_dialog.py ADDED
@@ -0,0 +1,358 @@
1
+ # Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from typing import List
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from torch.nn.parallel import DistributedDataParallel as DDP
22
+
23
+ from zipvoice.models.modules.zipformer_two_stream import TTSZipformerTwoStream
24
+ from zipvoice.models.zipvoice import ZipVoice
25
+ from zipvoice.utils.common import condition_time_mask_suffix, make_pad_mask, pad_labels
26
+
27
+
28
+ class ZipVoiceDialog(ZipVoice):
29
+ """The ZipVoice-Dialog model."""
30
+
31
+ def __init__(
32
+ self,
33
+ fm_decoder_downsampling_factor: List[int] = [1, 2, 4, 2, 1],
34
+ fm_decoder_num_layers: List[int] = [2, 2, 4, 4, 4],
35
+ fm_decoder_cnn_module_kernel: List[int] = [31, 15, 7, 15, 31],
36
+ fm_decoder_feedforward_dim: int = 1536,
37
+ fm_decoder_num_heads: int = 4,
38
+ fm_decoder_dim: int = 512,
39
+ text_encoder_num_layers: int = 4,
40
+ text_encoder_feedforward_dim: int = 512,
41
+ text_encoder_cnn_module_kernel: int = 9,
42
+ text_encoder_num_heads: int = 4,
43
+ text_encoder_dim: int = 192,
44
+ time_embed_dim: int = 192,
45
+ text_embed_dim: int = 192,
46
+ query_head_dim: int = 32,
47
+ value_head_dim: int = 12,
48
+ pos_head_dim: int = 4,
49
+ pos_dim: int = 48,
50
+ feat_dim: int = 100,
51
+ vocab_size: int = 26,
52
+ pad_id: int = 0,
53
+ spk_a_id: int = 360,
54
+ spk_b_id: int = 361,
55
+ ):
56
+ """
57
+ Initialize the model with specified configuration parameters.
58
+
59
+ Args:
60
+ fm_decoder_downsampling_factor: List of downsampling factors for each layer
61
+ in the flow-matching decoder.
62
+ fm_decoder_num_layers: List of the number of layers for each block in the
63
+ flow-matching decoder.
64
+ fm_decoder_cnn_module_kernel: List of kernel sizes for CNN modules in the
65
+ flow-matching decoder.
66
+ fm_decoder_feedforward_dim: Dimension of the feedforward network in the
67
+ flow-matching decoder.
68
+ fm_decoder_num_heads: Number of attention heads in the flow-matching
69
+ decoder.
70
+ fm_decoder_dim: Hidden dimension of the flow-matching decoder.
71
+ text_encoder_num_layers: Number of layers in the text encoder.
72
+ text_encoder_feedforward_dim: Dimension of the feedforward network in the
73
+ text encoder.
74
+ text_encoder_cnn_module_kernel: Kernel size for the CNN module in the
75
+ text encoder.
76
+ text_encoder_num_heads: Number of attention heads in the text encoder.
77
+ text_encoder_dim: Hidden dimension of the text encoder.
78
+ time_embed_dim: Dimension of the time embedding.
79
+ text_embed_dim: Dimension of the text embedding.
80
+ query_head_dim: Dimension of the query attention head.
81
+ value_head_dim: Dimension of the value attention head.
82
+ pos_head_dim: Dimension of the position attention head.
83
+ pos_dim: Dimension of the positional encoding.
84
+ feat_dim: Dimension of the acoustic features.
85
+ vocab_size: Size of the vocabulary.
86
+ pad_id: ID used for padding tokens.
87
+ spk_a_id: ID of speaker A / [S1].
88
+ spk_b_id: ID of speaker B / [S2].
89
+ """
90
+ super().__init__(
91
+ fm_decoder_downsampling_factor=fm_decoder_downsampling_factor,
92
+ fm_decoder_num_layers=fm_decoder_num_layers,
93
+ fm_decoder_cnn_module_kernel=fm_decoder_cnn_module_kernel,
94
+ fm_decoder_feedforward_dim=fm_decoder_feedforward_dim,
95
+ fm_decoder_num_heads=fm_decoder_num_heads,
96
+ fm_decoder_dim=fm_decoder_dim,
97
+ text_encoder_num_layers=text_encoder_num_layers,
98
+ text_encoder_feedforward_dim=text_encoder_feedforward_dim,
99
+ text_encoder_cnn_module_kernel=text_encoder_cnn_module_kernel,
100
+ text_encoder_num_heads=text_encoder_num_heads,
101
+ text_encoder_dim=text_encoder_dim,
102
+ time_embed_dim=time_embed_dim,
103
+ text_embed_dim=text_embed_dim,
104
+ query_head_dim=query_head_dim,
105
+ value_head_dim=value_head_dim,
106
+ pos_head_dim=pos_head_dim,
107
+ pos_dim=pos_dim,
108
+ feat_dim=feat_dim,
109
+ vocab_size=vocab_size,
110
+ pad_id=pad_id,
111
+ )
112
+
113
+ self.spk_a_id = spk_a_id
114
+ self.spk_b_id = spk_b_id
115
+ self.spk_embed = nn.Embedding(2, feat_dim)
116
+ torch.nn.init.normal_(self.spk_embed.weight, mean=0, std=0.1)
117
+
118
+ def extract_spk_indices(self, tensor):
119
+ turn_mask = ((tensor == self.spk_a_id) | (tensor == self.spk_b_id)).long()
120
+ turn_counts = turn_mask.cumsum(dim=1)
121
+ spk_mask = turn_counts % 2
122
+ spk_mask = torch.where(tensor == self.pad_id, -1, spk_mask)
123
+ spk_a_indices = torch.where(spk_mask == 0)
124
+ spk_b_indices = torch.where(spk_mask == 1)
125
+ return spk_a_indices, spk_b_indices
126
+
127
+ def forward_text_embed(
128
+ self,
129
+ tokens: List[List[int]],
130
+ ):
131
+ """
132
+ Get the text embeddings.
133
+ Args:
134
+ tokens: a list of list of token ids.
135
+ Returns:
136
+ embed: the text embeddings, shape (batch, seq_len, emb_dim).
137
+ tokens_lens: the length of each token sequence, shape (batch,).
138
+ """
139
+ device = (
140
+ self.device if isinstance(self, DDP) else next(self.parameters()).device
141
+ )
142
+ tokens_padded = pad_labels(tokens, pad_id=self.pad_id, device=device) # (B, S)
143
+ embed = self.embed(tokens_padded) # (B, S, C)
144
+ spk_a_indices, spk_b_indices = self.extract_spk_indices(tokens_padded)
145
+ tokens_lens = torch.tensor(
146
+ [len(token) for token in tokens], dtype=torch.int64, device=device
147
+ )
148
+ tokens_padding_mask = make_pad_mask(tokens_lens, embed.shape[1]) # (B, S)
149
+
150
+ embed = self.text_encoder(
151
+ x=embed, t=None, padding_mask=tokens_padding_mask
152
+ ) # (B, S, C)
153
+ embed[spk_a_indices] += self.spk_embed(torch.tensor(0, device=device)).to(
154
+ embed.dtype
155
+ )
156
+ embed[spk_b_indices] += self.spk_embed(torch.tensor(1, device=device)).to(
157
+ embed.dtype
158
+ )
159
+ return embed, tokens_lens
160
+
161
+ def forward(
162
+ self,
163
+ tokens: List[List[int]],
164
+ features: torch.Tensor,
165
+ features_lens: torch.Tensor,
166
+ noise: torch.Tensor,
167
+ t: torch.Tensor,
168
+ condition_drop_ratio: float = 0.0,
169
+ ) -> torch.Tensor:
170
+ """Forward pass of the model for training.
171
+ Args:
172
+ tokens: a list of list of token ids.
173
+ features: the acoustic features, with the shape (batch, seq_len, feat_dim).
174
+ features_lens: the length of each acoustic feature sequence, shape (batch,).
175
+ noise: the initial noise, with the shape (batch, seq_len, feat_dim).
176
+ t: the time step, with the shape (batch, 1, 1).
177
+ condition_drop_ratio: the ratio of dropped text condition.
178
+ Returns:
179
+ fm_loss: the flow-matching loss.
180
+ """
181
+
182
+ (text_condition, padding_mask,) = self.forward_text_train(
183
+ tokens=tokens,
184
+ features_lens=features_lens,
185
+ )
186
+
187
+ speech_condition_mask = condition_time_mask_suffix(
188
+ features_lens=features_lens,
189
+ mask_percent=(0.5, 1.0),
190
+ max_len=features.size(1),
191
+ )
192
+ speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
193
+
194
+ if condition_drop_ratio > 0.0:
195
+ drop_mask = (
196
+ torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
197
+ > condition_drop_ratio
198
+ )
199
+ text_condition = text_condition * drop_mask
200
+
201
+ xt = features * t + noise * (1 - t)
202
+ ut = features - noise # (B, T, F)
203
+
204
+ vt = self.forward_fm_decoder(
205
+ t=t,
206
+ xt=xt,
207
+ text_condition=text_condition,
208
+ speech_condition=speech_condition,
209
+ padding_mask=padding_mask,
210
+ )
211
+
212
+ loss_mask = speech_condition_mask & (~padding_mask)
213
+ fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
214
+
215
+ return fm_loss
216
+
217
+
218
+ class ZipVoiceDialogStereo(ZipVoiceDialog):
219
+ def __init__(self, *args, **kwargs):
220
+ super().__init__(*args, **kwargs)
221
+
222
+ required_params = {
223
+ "feat_dim",
224
+ "fm_decoder_downsampling_factor",
225
+ "fm_decoder_num_layers",
226
+ "fm_decoder_cnn_module_kernel",
227
+ "fm_decoder_dim",
228
+ "fm_decoder_feedforward_dim",
229
+ "fm_decoder_num_heads",
230
+ "query_head_dim",
231
+ "pos_head_dim",
232
+ "value_head_dim",
233
+ "pos_dim",
234
+ "time_embed_dim",
235
+ }
236
+
237
+ missing = [p for p in required_params if p not in kwargs]
238
+ if missing:
239
+ raise ValueError(f"Missing required parameters: {', '.join(missing)}")
240
+
241
+ self.fm_decoder = TTSZipformerTwoStream(
242
+ in_dim=(kwargs["feat_dim"] * 5, kwargs["feat_dim"] * 3),
243
+ out_dim=(kwargs["feat_dim"] * 2, kwargs["feat_dim"]),
244
+ downsampling_factor=kwargs["fm_decoder_downsampling_factor"],
245
+ num_encoder_layers=kwargs["fm_decoder_num_layers"],
246
+ cnn_module_kernel=kwargs["fm_decoder_cnn_module_kernel"],
247
+ encoder_dim=kwargs["fm_decoder_dim"],
248
+ feedforward_dim=kwargs["fm_decoder_feedforward_dim"],
249
+ num_heads=kwargs["fm_decoder_num_heads"],
250
+ query_head_dim=kwargs["query_head_dim"],
251
+ pos_head_dim=kwargs["pos_head_dim"],
252
+ value_head_dim=kwargs["value_head_dim"],
253
+ pos_dim=kwargs["pos_dim"],
254
+ use_time_embed=True,
255
+ time_embed_dim=kwargs["time_embed_dim"],
256
+ )
257
+
258
+ def forward(
259
+ self,
260
+ tokens: List[List[int]],
261
+ features: torch.Tensor,
262
+ features_lens: torch.Tensor,
263
+ noise: torch.Tensor,
264
+ t: torch.Tensor,
265
+ condition_drop_ratio: float = 0.0,
266
+ se_weight: float = 1.0,
267
+ ) -> torch.Tensor:
268
+ """Forward pass of the model for training.
269
+ Args:
270
+ tokens: a list of list of token ids.
271
+ features: the acoustic features, with the shape (batch, seq_len, feat_dim).
272
+ features_lens: the length of each acoustic feature sequence, shape (batch,).
273
+ noise: the initial noise, with the shape (batch, seq_len, feat_dim).
274
+ t: the time step, with the shape (batch, 1, 1).
275
+ condition_drop_ratio: the ratio of dropped text condition.
276
+ se_weight: the weight of the speaker exclusive loss.
277
+ Returns:
278
+ fm_loss: the flow-matching loss.
279
+ """
280
+
281
+ (text_condition, padding_mask,) = self.forward_text_train(
282
+ tokens=tokens,
283
+ features_lens=features_lens,
284
+ )
285
+
286
+ speech_condition_mask = condition_time_mask_suffix(
287
+ features_lens=features_lens,
288
+ mask_percent=(0.5, 1.0),
289
+ max_len=features.size(1),
290
+ )
291
+ speech_condition = torch.where(speech_condition_mask.unsqueeze(-1), 0, features)
292
+
293
+ if condition_drop_ratio > 0.0:
294
+ drop_mask = (
295
+ torch.rand(text_condition.size(0), 1, 1).to(text_condition.device)
296
+ > condition_drop_ratio
297
+ )
298
+ text_condition = text_condition * drop_mask
299
+
300
+ xt = features * t + noise * (1 - t)
301
+ ut = features - noise # (B, T, F)
302
+
303
+ vt = self.forward_fm_decoder(
304
+ t=t,
305
+ xt=xt,
306
+ text_condition=text_condition,
307
+ speech_condition=speech_condition,
308
+ padding_mask=padding_mask,
309
+ )
310
+
311
+ loss_mask = speech_condition_mask & (~padding_mask)
312
+ fm_loss = torch.mean((vt[loss_mask] - ut[loss_mask]) ** 2)
313
+
314
+ if se_weight > 0:
315
+ target = xt + vt * (1 - t)
316
+ fbank_1 = target[:, :, : self.feat_dim]
317
+ fbank_2 = target[:, :, self.feat_dim :]
318
+ energy_loss = torch.mean(
319
+ self.energy_based_loss(fbank_1, fbank_2, features)[loss_mask]
320
+ )
321
+ loss = fm_loss + energy_loss * se_weight
322
+ else:
323
+ loss = fm_loss
324
+
325
+ return loss
326
+
327
+ def energy_based_loss(self, fbank1, fbank2, gt_fbank):
328
+ energy1 = self.energy(fbank1)
329
+ energy2 = self.energy(fbank2)
330
+
331
+ energy_thresholds = self.adaptive_threshold_from_gt(
332
+ torch.cat(
333
+ [
334
+ gt_fbank[:, :, : self.feat_dim],
335
+ gt_fbank[:, :, self.feat_dim :],
336
+ ],
337
+ dim=1,
338
+ )
339
+ )
340
+
341
+ both_speaking = (
342
+ (energy1 > energy_thresholds) & (energy2 > energy_thresholds)
343
+ ).float()
344
+
345
+ penalty = (
346
+ both_speaking
347
+ * (energy1 - energy_thresholds)
348
+ * (energy2 - energy_thresholds)
349
+ )
350
+ return penalty
351
+
352
+ def energy(self, fbank):
353
+ return torch.mean(fbank, dim=-1)
354
+
355
+ def adaptive_threshold_from_gt(self, gt_fbank, percentile=50):
356
+ frame_energies = self.energy(gt_fbank)
357
+ thresholds = torch.quantile(frame_energies, q=percentile / 100, dim=1)
358
+ return thresholds.unsqueeze(1)
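The speaker embeddings added in forward_text_embed() above rely on extract_spk_indices(), which turns the [S1]/[S2] turn tokens into alternating per-position speaker ids via a cumulative sum. A small self-contained illustration follows; 360, 361 and 0 are the class defaults for spk_a_id, spk_b_id and pad_id, while the remaining token ids are arbitrary values chosen for the example.

import torch

SPK_A, SPK_B, PAD = 360, 361, 0
tokens = torch.tensor([[SPK_A, 5, 6, SPK_B, 7, SPK_A, 8, PAD]])

turn_mask = ((tokens == SPK_A) | (tokens == SPK_B)).long()   # 1 at every turn token
turn_counts = turn_mask.cumsum(dim=1)                        # [1, 1, 1, 2, 2, 3, 3, 3]
spk_mask = turn_counts % 2                                    # alternates between 1 and 0 per turn
spk_mask = torch.where(tokens == PAD, -1, spk_mask)           # padding positions marked as -1
print(spk_mask)                                               # tensor([[ 1,  1,  1,  0,  0,  1,  1, -1]])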
zipvoice/models/zipvoice_distill.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright 2024 Xiaomi Corp. (authors: Wei Kang
2
+ # Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from typing import List
19
+
20
+ import torch
21
+
22
+ from zipvoice.models.modules.solver import DistillEulerSolver
23
+ from zipvoice.models.modules.zipformer import TTSZipformer
24
+ from zipvoice.models.zipvoice import ZipVoice
25
+
26
+
27
+ class ZipVoiceDistill(ZipVoice):
28
+ """ZipVoice-Distill model."""
29
+
30
+ def __init__(self, *args, **kwargs):
31
+ super().__init__(*args, **kwargs)
32
+
33
+ required_params = {
34
+ "feat_dim",
35
+ "fm_decoder_downsampling_factor",
36
+ "fm_decoder_num_layers",
37
+ "fm_decoder_cnn_module_kernel",
38
+ "fm_decoder_dim",
39
+ "fm_decoder_feedforward_dim",
40
+ "fm_decoder_num_heads",
41
+ "query_head_dim",
42
+ "pos_head_dim",
43
+ "value_head_dim",
44
+ "pos_dim",
45
+ "time_embed_dim",
46
+ }
47
+
48
+ missing = [p for p in required_params if p not in kwargs]
49
+ if missing:
50
+ raise ValueError(f"Missing required parameters: {', '.join(missing)}")
51
+
52
+ self.fm_decoder = TTSZipformer(
53
+ in_dim=kwargs["feat_dim"] * 3,
54
+ out_dim=kwargs["feat_dim"],
55
+ downsampling_factor=kwargs["fm_decoder_downsampling_factor"],
56
+ num_encoder_layers=kwargs["fm_decoder_num_layers"],
57
+ cnn_module_kernel=kwargs["fm_decoder_cnn_module_kernel"],
58
+ encoder_dim=kwargs["fm_decoder_dim"],
59
+ feedforward_dim=kwargs["fm_decoder_feedforward_dim"],
60
+ num_heads=kwargs["fm_decoder_num_heads"],
61
+ query_head_dim=kwargs["query_head_dim"],
62
+ pos_head_dim=kwargs["pos_head_dim"],
63
+ value_head_dim=kwargs["value_head_dim"],
64
+ pos_dim=kwargs["pos_dim"],
65
+ use_time_embed=True,
66
+ time_embed_dim=kwargs["time_embed_dim"],
67
+ use_guidance_scale_embed=True,
68
+ )
69
+ self.solver = DistillEulerSolver(self, func_name="forward_fm_decoder")
70
+
71
+ def forward(
72
+ self,
73
+ tokens: List[List[int]],
74
+ features: torch.Tensor,
75
+ features_lens: torch.Tensor,
76
+ noise: torch.Tensor,
77
+ speech_condition_mask: torch.Tensor,
78
+ t_start: float,
79
+ t_end: float,
80
+ num_step: int = 1,
81
+ guidance_scale: torch.Tensor = None,
82
+ ) -> torch.Tensor:
83
+
84
+ return self.sample_intermediate(
85
+ tokens=tokens,
86
+ features=features,
87
+ features_lens=features_lens,
88
+ noise=noise,
89
+ speech_condition_mask=speech_condition_mask,
90
+ t_start=t_start,
91
+ t_end=t_end,
92
+ num_step=num_step,
93
+ guidance_scale=guidance_scale,
94
+ )
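Both EulerSolver and DistillEulerSolver (imported from zipvoice.models.modules.solver, which is not shown in this diff) integrate the learned velocity field between a start and an end timestep. As a rough mental model only, an explicit Euler update over num_step steps looks like the sketch below; the velocity function here is a toy stand-in, not the project's API.

import torch

def euler_sample(x, velocity_fn, t_start=0.0, t_end=1.0, num_step=5):
    ts = torch.linspace(t_start, t_end, num_step + 1)
    for i in range(num_step):
        t, dt = ts[i], ts[i + 1] - ts[i]
        x = x + velocity_fn(t, x) * dt   # x_{t+dt} = x_t + v(t, x_t) * dt
    return x

x0 = torch.randn(2, 50, 100)             # initial noise
x1 = euler_sample(x0, lambda t, x: -x)   # toy velocity field, for illustration only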
zipvoice/tokenizer/normalizer.py ADDED
@@ -0,0 +1,170 @@
1
+ import re
2
+ from abc import ABC, abstractmethod
3
+
4
+ import cn2an
5
+ import inflect
6
+
7
+
8
+ class TextNormalizer(ABC):
9
+ """Abstract base class for text normalization, defining common interface."""
10
+
11
+ @abstractmethod
12
+ def normalize(self, text: str) -> str:
13
+ """Normalize text."""
14
+ raise NotImplementedError
15
+
16
+
17
+ class EnglishTextNormalizer(TextNormalizer):
18
+ """
19
+ A class to handle preprocessing of English text including normalization. Following:
20
+ https://github.com/espnet/espnet_tts_frontend/blob/master/tacotron_cleaner/cleaners.py
21
+ """
22
+
23
+ def __init__(self):
24
+ # List of (regular expression, replacement) pairs for abbreviations:
25
+ self._abbreviations = [
26
+ (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
27
+ for x in [
28
+ ("mrs", "misess"),
29
+ ("mr", "mister"),
30
+ ("dr", "doctor"),
31
+ ("st", "saint"),
32
+ ("co", "company"),
33
+ ("jr", "junior"),
34
+ ("maj", "major"),
35
+ ("gen", "general"),
36
+ ("drs", "doctors"),
37
+ ("rev", "reverend"),
38
+ ("lt", "lieutenant"),
39
+ ("hon", "honorable"),
40
+ ("sgt", "sergeant"),
41
+ ("capt", "captain"),
42
+ ("esq", "esquire"),
43
+ ("ltd", "limited"),
44
+ ("col", "colonel"),
45
+ ("ft", "fort"),
46
+ ("etc", "et cetera"),
47
+ ("btw", "by the way"),
48
+ ]
49
+ ]
50
+
51
+ self._inflect = inflect.engine()
52
+ self._comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
53
+ self._decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
54
+ self._percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
55
+ self._pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
56
+ self._dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
57
+ self._fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
58
+ self._ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
59
+ self._number_re = re.compile(r"[0-9]+")
60
+ self._whitespace_re = re.compile(r"\s+")
61
+
62
+ def normalize(self, text: str) -> str:
63
+ """Custom pipeline for English text,
64
+ including number and abbreviation expansion."""
65
+ text = self.expand_abbreviations(text)
66
+ text = self.normalize_numbers(text)
67
+
68
+ return text
69
+
70
+ def fraction_to_words(self, numerator, denominator):
71
+ if numerator == 1 and denominator == 2:
72
+ return " one half "
73
+ if numerator == 1 and denominator == 4:
74
+ return " one quarter "
75
+ if denominator == 2:
76
+ return " " + self._inflect.number_to_words(numerator) + " halves "
77
+ if denominator == 4:
78
+ return " " + self._inflect.number_to_words(numerator) + " quarters "
79
+ return (
80
+ " "
81
+ + self._inflect.number_to_words(numerator)
82
+ + " "
83
+ + self._inflect.ordinal(self._inflect.number_to_words(denominator))
84
+ + " "
85
+ )
86
+
87
+ def _remove_commas(self, m):
88
+ return m.group(1).replace(",", "")
89
+
90
+ def _expand_dollars(self, m):
91
+ match = m.group(1)
92
+ parts = match.split(".")
93
+ if len(parts) > 2:
94
+ return " " + match + " dollars " # Unexpected format
95
+ dollars = int(parts[0]) if parts[0] else 0
96
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
97
+ if dollars and cents:
98
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
99
+ cent_unit = "cent" if cents == 1 else "cents"
100
+ return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
101
+ elif dollars:
102
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
103
+ return " %s %s " % (dollars, dollar_unit)
104
+ elif cents:
105
+ cent_unit = "cent" if cents == 1 else "cents"
106
+ return " %s %s " % (cents, cent_unit)
107
+ else:
108
+ return " zero dollars "
109
+
110
+ def _expand_fraction(self, m):
111
+ numerator = int(m.group(1))
112
+ denominator = int(m.group(2))
113
+ return self.fraction_to_words(numerator, denominator)
114
+
115
+ def _expand_decimal_point(self, m):
116
+ return m.group(1).replace(".", " point ")
117
+
118
+ def _expand_percent(self, m):
119
+ return m.group(1).replace("%", " percent ")
120
+
121
+ def _expand_ordinal(self, m):
122
+ return " " + self._inflect.number_to_words(m.group(0)) + " "
123
+
124
+ def _expand_number(self, m):
125
+ num = int(m.group(0))
126
+ if num > 1000 and num < 3000:
127
+ if num == 2000:
128
+ return " two thousand "
129
+ elif num > 2000 and num < 2010:
130
+ return " two thousand " + self._inflect.number_to_words(num % 100) + " "
131
+ elif num % 100 == 0:
132
+ return " " + self._inflect.number_to_words(num // 100) + " hundred "
133
+ else:
134
+ return (
135
+ " "
136
+ + self._inflect.number_to_words(
137
+ num, andword="", zero="oh", group=2
138
+ ).replace(", ", " ")
139
+ + " "
140
+ )
141
+ else:
142
+ return " " + self._inflect.number_to_words(num, andword="") + " "
143
+
144
+ def normalize_numbers(self, text):
145
+ text = re.sub(self._comma_number_re, self._remove_commas, text)
146
+ text = re.sub(self._pounds_re, r"\1 pounds", text)
147
+ text = re.sub(self._dollars_re, self._expand_dollars, text)
148
+ text = re.sub(self._fraction_re, self._expand_fraction, text)
149
+ text = re.sub(self._decimal_number_re, self._expand_decimal_point, text)
150
+ text = re.sub(self._percent_number_re, self._expand_percent, text)
151
+ text = re.sub(self._ordinal_re, self._expand_ordinal, text)
152
+ text = re.sub(self._number_re, self._expand_number, text)
153
+ return text
154
+
155
+ def expand_abbreviations(self, text):
156
+ for regex, replacement in self._abbreviations:
157
+ text = re.sub(regex, replacement, text)
158
+ return text
159
+
160
+
161
+ class ChineseTextNormalizer(TextNormalizer):
162
+ """
163
+ A class to handle preprocessing of Chinese text including normalization.
164
+ """
165
+
166
+ def normalize(self, text: str) -> str:
167
+ """Normalize text."""
168
+ # Convert numbers to Chinese
169
+ text = cn2an.transform(text, "an2cn")
170
+ return text
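A quick usage sketch of the two normalizers above, assuming the zipvoice package and its inflect/cn2an dependencies are installed; the exact spacing of the outputs is determined by the regex replacements defined in the class.

from zipvoice.tokenizer.normalizer import ChineseTextNormalizer, EnglishTextNormalizer

en = EnglishTextNormalizer()
print(en.normalize("Dr. Smith paid $5.50 on the 3rd."))
# expands the abbreviation ("doctor"), the amount ("five dollars, fifty cents")
# and the ordinal ("third")

zh = ChineseTextNormalizer()
print(zh.normalize("我有2个苹果"))
# cn2an converts the Arabic numeral into its Chinese-character form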
zipvoice/tokenizer/tokenizer.py ADDED
@@ -0,0 +1,618 @@
1
+ # Copyright 2023-2024 Xiaomi Corp. (authors: Zengwei Yao
2
+ # Han Zhu,
3
+ # Wei Kang)
4
+ #
5
+ # See ../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ import logging
20
+ import re
21
+ from abc import ABC, abstractmethod
22
+ from functools import reduce
23
+ from typing import Dict, List, Optional
24
+
25
+ import jieba
26
+ from pypinyin import Style, lazy_pinyin
27
+ from pypinyin.contrib.tone_convert import to_finals_tone3, to_initials
28
+
29
+ from zipvoice.tokenizer.normalizer import ChineseTextNormalizer, EnglishTextNormalizer
30
+
31
+ try:
32
+ from piper_phonemize import phonemize_espeak
33
+ except Exception as ex:
34
+ raise RuntimeError(
35
+ f"{ex}\nPlease run\n"
36
+ "pip install piper_phonemize -f \
37
+ https://k2-fsa.github.io/icefall/piper_phonemize.html"
38
+ )
39
+
40
+
41
+ class Tokenizer(ABC):
42
+ """Abstract base class for tokenizers, defining common interface."""
43
+
44
+ @abstractmethod
45
+ def texts_to_token_ids(self, texts: List[str]) -> List[List[int]]:
46
+ """Convert list of texts to list of token id sequences."""
47
+ raise NotImplementedError
48
+
49
+ @abstractmethod
50
+ def texts_to_tokens(self, texts: List[str]) -> List[List[str]]:
51
+ """Convert list of texts to list of token sequences."""
52
+ raise NotImplementedError
53
+
54
+ @abstractmethod
55
+ def tokens_to_token_ids(self, tokens: List[List[str]]) -> List[List[int]]:
56
+ """Convert list of token sequences to list of token id sequences."""
57
+ raise NotImplementedError
58
+
59
+
60
+ class SimpleTokenizer(Tokenizer):
61
+ """The simplest tokenizer, treating every character as a token,
62
+ without text normalization.
63
+ """
64
+
65
+ def __init__(self, token_file: Optional[str] = None):
66
+ """
67
+ Args:
68
+ tokens: the file that contains information that maps tokens to ids,
69
+ which is a text file with '{token}\t{token_id}' per line.
70
+ """
71
+ # Parse token file
72
+ self.has_tokens = False
73
+ if token_file is None:
74
+ logging.debug(
75
+ "Initialize Tokenizer without tokens file, \
76
+ will fail when mapping to ids."
77
+ )
78
+ return
79
+ self.token2id: Dict[str, int] = {}
80
+ with open(token_file, "r", encoding="utf-8") as f:
81
+ for line in f.readlines():
82
+ info = line.rstrip().split("\t")
83
+ token, id = info[0], int(info[1])
84
+ assert token not in self.token2id, token
85
+ self.token2id[token] = id
86
+ self.pad_id = self.token2id["_"] # padding
87
+ self.vocab_size = len(self.token2id)
88
+ self.has_tokens = True
89
+
90
+ def texts_to_token_ids(
91
+ self,
92
+ texts: List[str],
93
+ ) -> List[List[int]]:
94
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
95
+
96
+ def texts_to_tokens(
97
+ self,
98
+ texts: List[str],
99
+ ) -> List[List[str]]:
100
+ tokens_list = [list(texts[i]) for i in range(len(texts))]
101
+ return tokens_list
102
+
103
+ def tokens_to_token_ids(
104
+ self,
105
+ tokens_list: List[List[str]],
106
+ ) -> List[List[int]]:
107
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
108
+
109
+ token_ids_list = []
110
+
111
+ for tokens in tokens_list:
112
+ token_ids = []
113
+ for t in tokens:
114
+ if t not in self.token2id:
115
+ logging.debug(f"Skip OOV {t}")
116
+ continue
117
+ token_ids.append(self.token2id[t])
118
+
119
+ token_ids_list.append(token_ids)
120
+
121
+ return token_ids_list
122
+
123
+
124
+ class EspeakTokenizer(Tokenizer):
125
+ """A simple tokenizer with Espeak g2p function."""
126
+
127
+ def __init__(self, token_file: Optional[str] = None, lang: str = "en-us"):
128
+ """
129
+ Args:
130
+ tokens: the file that contains information that maps tokens to ids,
131
+ which is a text file with '{token}\t{token_id}' per line.
132
+ lang: the language identifier, see
133
+ https://github.com/rhasspy/espeak-ng/blob/master/docs/languages.md
134
+ """
135
+ # Parse token file
136
+ self.has_tokens = False
137
+ if token_file is None:
138
+ logging.debug(
139
+ "Initialize Tokenizer without tokens file, \
140
+ will fail when mapping to ids."
141
+ )
142
+ return
143
+ self.token2id: Dict[str, int] = {}
144
+ with open(token_file, "r", encoding="utf-8") as f:
145
+ for line in f.readlines():
146
+ info = line.rstrip().split("\t")
147
+ token, id = info[0], int(info[1])
148
+ assert token not in self.token2id, token
149
+ self.token2id[token] = id
150
+ self.pad_id = self.token2id["_"] # padding
151
+ self.vocab_size = len(self.token2id)
152
+ self.has_tokens = True
153
+ self.lang = lang
154
+
155
+ def g2p(self, text: str) -> List[str]:
156
+ try:
157
+ tokens = phonemize_espeak(text, self.lang)
158
+ tokens = reduce(lambda x, y: x + y, tokens)
159
+ return tokens
160
+ except Exception as ex:
161
+ logging.warning(f"Tokenization of {self.lang} texts failed: {ex}")
162
+ return []
163
+
164
+ def texts_to_token_ids(
165
+ self,
166
+ texts: List[str],
167
+ ) -> List[List[int]]:
168
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
169
+
170
+ def texts_to_tokens(
171
+ self,
172
+ texts: List[str],
173
+ ) -> List[List[str]]:
174
+ tokens_list = [self.g2p(texts[i]) for i in range(len(texts))]
175
+ return tokens_list
176
+
177
+ def tokens_to_token_ids(
178
+ self,
179
+ tokens_list: List[List[str]],
180
+ ) -> List[List[int]]:
181
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
182
+
183
+ token_ids_list = []
184
+
185
+ for tokens in tokens_list:
186
+ token_ids = []
187
+ for t in tokens:
188
+ if t not in self.token2id:
189
+ logging.debug(f"Skip OOV {t}")
190
+ continue
191
+ token_ids.append(self.token2id[t])
192
+
193
+ token_ids_list.append(token_ids)
194
+
195
+ return token_ids_list
196
+
197
+
198
+ class EmiliaTokenizer(Tokenizer):
199
+ def __init__(self, token_file: Optional[str] = None, token_type="phone"):
200
+ """
201
+ Args:
202
+ tokens: the file that contains information that maps tokens to ids,
203
+ which is a text file with '{token}\t{token_id}' per line.
204
+ """
205
+ assert (
206
+ token_type == "phone"
207
+ ), f"Only support phone tokenizer for Emilia, but get {token_type}."
208
+
209
+ self.english_normalizer = EnglishTextNormalizer()
210
+ self.chinese_normalizer = ChineseTextNormalizer()
211
+
212
+ self.has_tokens = False
213
+ if token_file is None:
214
+ logging.debug(
215
+ "Initialize Tokenizer without tokens file, \
216
+ will fail when mapping to ids."
217
+ )
218
+ return
219
+ self.token2id: Dict[str, int] = {}
220
+ with open(token_file, "r", encoding="utf-8") as f:
221
+ for line in f.readlines():
222
+ info = line.rstrip().split("\t")
223
+ token, id = info[0], int(info[1])
224
+ assert token not in self.token2id, token
225
+ self.token2id[token] = id
226
+ self.pad_id = self.token2id["_"] # padding
227
+
228
+ self.vocab_size = len(self.token2id)
229
+ self.has_tokens = True
230
+
231
+ def texts_to_token_ids(
232
+ self,
233
+ texts: List[str],
234
+ ) -> List[List[int]]:
235
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
236
+
237
+ def preprocess_text(
238
+ self,
239
+ text: str,
240
+ ) -> str:
241
+ return self.map_punctuations(text)
242
+
243
+ def texts_to_tokens(
244
+ self,
245
+ texts: List[str],
246
+ ) -> List[List[str]]:
247
+ for i in range(len(texts)):
248
+ # Text normalization
249
+ texts[i] = self.preprocess_text(texts[i])
250
+
251
+ phoneme_list = []
252
+ for text in texts:
253
+ # now only en and ch
254
+ segments = self.get_segment(text)
255
+ all_phoneme = []
256
+ for index in range(len(segments)):
257
+ seg = segments[index]
258
+ if seg[1] == "zh":
259
+ phoneme = self.tokenize_ZH(seg[0])
260
+ elif seg[1] == "en":
261
+ phoneme = self.tokenize_EN(seg[0])
262
+ elif seg[1] == "pinyin":
263
+ phoneme = self.tokenize_pinyin(seg[0])
264
+ elif seg[1] == "tag":
265
+ phoneme = [seg[0]]
266
+ else:
267
+ logging.warning(
268
+ f"No English or Chinese characters found, \
269
+ skipping segment of unknown language: {seg}"
270
+ )
271
+ continue
272
+ all_phoneme += phoneme
273
+ phoneme_list.append(all_phoneme)
274
+ return phoneme_list
275
+
276
+ def tokens_to_token_ids(
277
+ self,
278
+ tokens_list: List[List[str]],
279
+ ) -> List[List[int]]:
280
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
281
+ token_ids_list = []
282
+
283
+ for tokens in tokens_list:
284
+ token_ids = []
285
+ for t in tokens:
286
+ if t not in self.token2id:
287
+ logging.debug(f"Skip OOV {t}")
288
+ continue
289
+ token_ids.append(self.token2id[t])
290
+
291
+ token_ids_list.append(token_ids)
292
+
293
+ return token_ids_list
294
+
295
+ def tokenize_ZH(self, text: str) -> List[str]:
296
+ try:
297
+ text = self.chinese_normalizer.normalize(text)
298
+ segs = list(jieba.cut(text))
299
+ full = lazy_pinyin(
300
+ segs,
301
+ style=Style.TONE3,
302
+ tone_sandhi=True,
303
+ neutral_tone_with_five=True,
304
+ )
305
+ phones = []
306
+ for x in full:
307
+ # valid pinyin (in tone3 style) is alphabet + 1 number in [1-5].
308
+ if not (x[0:-1].isalpha() and x[-1] in ("1", "2", "3", "4", "5")):
309
+ phones.append(x)
310
+ continue
311
+ else:
312
+ phones.extend(self.seperate_pinyin(x))
313
+ return phones
314
+ except Exception as ex:
315
+ logging.warning(f"Tokenization of Chinese texts failed: {ex}")
316
+ return []
317
+
318
+ def tokenize_EN(self, text: str) -> List[str]:
319
+ try:
320
+ text = self.english_normalizer.normalize(text)
321
+ tokens = phonemize_espeak(text, "en-us")
322
+ tokens = reduce(lambda x, y: x + y, tokens)
323
+ return tokens
324
+ except Exception as ex:
325
+ logging.warning(f"Tokenization of English texts failed: {ex}")
326
+ return []
327
+
328
+ def tokenize_pinyin(self, text: str) -> List[str]:
329
+ try:
330
+ assert text.startswith("<") and text.endswith(">")
331
+ text = text.lstrip("<").rstrip(">")
332
+ # valid pinyin (in tone3 style) is alphabet + 1 number in [1-5].
333
+ if not (text[0:-1].isalpha() and text[-1] in ("1", "2", "3", "4", "5")):
334
+ logging.warning(
335
+ f"Strings enclosed with <> should be pinyin, \
336
+ but got: {text}. Skipped it. "
337
+ )
338
+ return []
339
+ else:
340
+ return self.seperate_pinyin(text)
341
+ except Exception as ex:
342
+ logging.warning(f"Tokenize pinyin failed: {ex}")
343
+ return []
344
+
345
+ def seperate_pinyin(self, text: str) -> List[str]:
346
+ """
347
+ Separate pinyin into initial and final
348
+ """
349
+ pinyins = []
350
+ initial = to_initials(text, strict=False)
351
+ # don't want to share tokens with espeak tokens,
352
+ # so use tone3 style
353
+ final = to_finals_tone3(
354
+ text,
355
+ strict=False,
356
+ neutral_tone_with_five=True,
357
+ )
358
+ if initial != "":
359
+ # don't want to share tokens with espeak tokens,
360
+ # so add a '0' after each initial
361
+ pinyins.append(initial + "0")
362
+ if final != "":
363
+ pinyins.append(final)
364
+ return pinyins
365
+
366
+ def map_punctuations(self, text):
367
+ text = text.replace(",", ",")
368
+ text = text.replace("。", ".")
369
+ text = text.replace("!", "!")
370
+ text = text.replace("?", "?")
371
+ text = text.replace(";", ";")
372
+ text = text.replace(":", ":")
373
+ text = text.replace("、", ",")
374
+ text = text.replace("‘", "'")
375
+ text = text.replace("“", '"')
376
+ text = text.replace("”", '"')
377
+ text = text.replace("’", "'")
378
+ text = text.replace("⋯", "…")
379
+ text = text.replace("···", "…")
380
+ text = text.replace("・・・", "…")
381
+ text = text.replace("...", "…")
382
+ return text
383
+
384
+ def get_segment(self, text: str) -> List[str]:
385
+ """
386
+ Split a text into segments based on language types
387
+ (Chinese, English, Pinyin, tags, etc.)
388
+
389
+ Args:
390
+ text (str): Input text to be segmented
391
+
392
+ Returns:
393
+ List[str]: Segmented text parts with their language types
394
+
395
+ Example:
396
+ Input: 我们是小米人,是吗? Yes I think so!霍...啦啦啦
397
+ Output: [('我们是小米人,是吗? ', 'zh'),
398
+ ('Yes I think so!', 'en'), ('霍...啦啦啦', 'zh')]
399
+ """
400
+ # Stores the final segmented parts and their language types
401
+ segments = []
402
+ # Stores the language type of each character in the input text
403
+ types = []
404
+ temp_seg = ""
405
+ temp_lang = ""
406
+
407
+ # Each part is a character, or a special string enclosed in <> and []
408
+ # <> denotes pinyin string, [] denotes other special strings.
409
+ _part_pattern = re.compile(r"[<[].*?[>\]]|.")
410
+ text = _part_pattern.findall(text)
411
+
412
+ for i, part in enumerate(text):
413
+ if self.is_chinese(part) or self.is_pinyin(part):
414
+ types.append("zh")
415
+ elif self.is_alphabet(part):
416
+ types.append("en")
417
+ else:
418
+ types.append("other")
419
+
420
+ assert len(types) == len(text)
421
+
422
+ for i in range(len(types)):
423
+ # find the first char of the seg
424
+ if i == 0:
425
+ temp_seg += text[i]
426
+ temp_lang = types[i]
427
+ else:
428
+ if temp_lang == "other":
429
+ temp_seg += text[i]
430
+ temp_lang = types[i]
431
+ else:
432
+ if types[i] in [temp_lang, "other"]:
433
+ temp_seg += text[i]
434
+ else:
435
+ segments.append((temp_seg, temp_lang))
436
+ temp_seg = text[i]
437
+ temp_lang = types[i]
438
+
439
+ segments.append((temp_seg, temp_lang))
440
+
441
+ # Handle "pinyin" and "tag" types
442
+ segments = self.split_segments(segments)
443
+ return segments
444
+
445
+ def split_segments(self, segments):
446
+ """
447
+ split segments into smaller parts if special strings enclosed by [] or <>
448
+ are found, where <> denotes pinyin strings, [] denotes other special strings.
449
+
450
+ Args:
451
+ segments (list): A list of tuples where each tuple contains:
452
+ - temp_seg (str): The text segment to be split.
453
+ - temp_lang (str): The language code associated with the segment.
454
+
455
+ Returns:
456
+ list: A list of smaller segments.
457
+ """
458
+ result = []
459
+ for temp_seg, temp_lang in segments:
460
+ parts = re.split(r"([<[].*?[>\]])", temp_seg)
461
+ for part in parts:
462
+ if not part:
463
+ continue
464
+ if self.is_pinyin(part):
465
+ result.append((part, "pinyin"))
466
+ elif self.is_tag(part):
467
+ result.append((part, "tag"))
468
+ else:
469
+ result.append((part, temp_lang))
470
+ return result
471
+
472
+ def is_chinese(self, char: str) -> bool:
473
+ if char >= "\u4e00" and char <= "\u9fa5":
474
+ return True
475
+ else:
476
+ return False
477
+
478
+ def is_alphabet(self, char: str) -> bool:
479
+ if (char >= "\u0041" and char <= "\u005a") or (
480
+ char >= "\u0061" and char <= "\u007a"
481
+ ):
482
+ return True
483
+ else:
484
+ return False
485
+
486
+ def is_pinyin(self, part: str) -> bool:
487
+ if part.startswith("<") and part.endswith(">"):
488
+ return True
489
+ else:
490
+ return False
491
+
492
+ def is_tag(self, part: str) -> bool:
493
+ if part.startswith("[") and part.endswith("]"):
494
+ return True
495
+ else:
496
+ return False
497
+
498
+
499
+ class DialogTokenizer(EmiliaTokenizer):
500
+ def __init__(self, token_file: Optional[str] = None, token_type="phone"):
501
+ super().__init__(token_file=token_file, token_type=token_type)
502
+ self.spk_a_id = self.token2id["[S1]"]
503
+ self.spk_b_id = self.token2id["[S2]"]
504
+
505
+ def preprocess_text(
506
+ self,
507
+ text: str,
508
+ ) -> str:
509
+ text = re.sub(r"\s*(\[S[12]\])\s*", r"\1", text)
510
+ text = self.map_punctuations(text)
511
+ return text
512
+
513
+
514
+ class LibriTTSTokenizer(Tokenizer):
515
+ def __init__(self, token_file: Optional[str] = None, token_type="char"):
516
+ """
517
+ Args:
518
+ type: the type of tokenizer, e.g., bpe, char, phone.
519
+ tokens: the file that contains information that maps tokens to ids,
520
+ which is a text file with '{token}\t{token_id}' per line if type is
521
+ char or phone, otherwise it is a bpe_model file.
522
+ """
523
+ self.type = token_type
524
+ assert token_type in ["bpe", "char", "phone"]
525
+ try:
526
+ import tacotron_cleaner.cleaners
527
+ except Exception as ex:
528
+ raise RuntimeError(f"{ex}\nPlease run\n" "pip install espnet_tts_frontend")
529
+
530
+ self.normalize = tacotron_cleaner.cleaners.custom_english_cleaners
531
+
532
+ self.has_tokens = False
533
+ if token_file is None:
534
+ logging.debug(
535
+ "Initialize Tokenizer without tokens file, \
536
+ will fail when mapping to ids."
537
+ )
538
+ return
539
+ if token_type == "bpe":
540
+ import sentencepiece as spm
541
+
542
+ self.sp = spm.SentencePieceProcessor()
543
+ self.sp.load(token_file)
544
+ self.pad_id = self.sp.piece_to_id("<pad>")
545
+ self.vocab_size = self.sp.get_piece_size()
546
+ else:
547
+ self.token2id: Dict[str, int] = {}
548
+ with open(token_file, "r", encoding="utf-8") as f:
549
+ for line in f.readlines():
550
+ info = line.rstrip().split("\t")
551
+ token, id = info[0], int(info[1])
552
+ assert token not in self.token2id, token
553
+ self.token2id[token] = id
554
+ self.pad_id = self.token2id["_"] # padding
555
+ self.vocab_size = len(self.token2id)
556
+ self.has_tokens = True
557
+
558
+ def texts_to_token_ids(
559
+ self,
560
+ texts: List[str],
561
+ ) -> List[List[int]]:
562
+ if self.type == "bpe":
563
+ for i in range(len(texts)):
564
+ texts[i] = self.normalize(texts[i])
565
+ return self.sp.encode(texts)
566
+ else:
567
+ return self.tokens_to_token_ids(self.texts_to_tokens(texts))
568
+
569
+ def texts_to_tokens(
570
+ self,
571
+ texts: List[str],
572
+ ) -> List[List[str]]:
573
+ for i in range(len(texts)):
574
+ texts[i] = self.normalize(texts[i])
575
+
576
+ if self.type == "char":
577
+ tokens_list = [list(texts[i]) for i in range(len(texts))]
578
+ elif self.type == "phone":
579
+ tokens_list = [
580
+ phonemize_espeak(texts[i].lower(), "en-us") for i in range(len(texts))
581
+ ]
582
+ elif self.type == "bpe":
583
+ tokens_list = self.sp.encode(texts, out_type=str)
584
+
585
+ return tokens_list
586
+
587
+ def tokens_to_token_ids(
588
+ self,
589
+ tokens_list: List[List[str]],
590
+ ) -> List[List[int]]:
591
+ assert self.has_tokens, "Please initialize Tokenizer with a tokens file."
592
+
593
+ assert self.type != "bpe", "BPE tokenizer does not support this function."
594
+
595
+ token_ids_list = []
596
+
597
+ for tokens in tokens_list:
598
+ token_ids = []
599
+ for t in tokens:
600
+ if t not in self.token2id:
601
+ logging.debug(f"Skip OOV {t}")
602
+ continue
603
+ token_ids.append(self.token2id[t])
604
+
605
+ token_ids_list.append(token_ids)
606
+
607
+ return token_ids_list
608
+
609
+
610
+ if __name__ == "__main__":
611
+ text = (
612
+ "我们是5年小米人,是吗? Yes I think so! "
613
+ "mr king, 5 years, from 2019 to 2024."
614
+ "霍...啦啦啦超过90%的人<le5>...?!9204"
615
+ )
616
+ tokenizer = EmiliaTokenizer()
617
+ tokens = tokenizer.texts_to_tokens([text])
618
+ print(f"tokens: {'|'.join(tokens[0])}")
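All of the tokenizers above that take a token_file expect a plain-text vocabulary with one "{token}\t{token_id}" pair per line, including the "_" padding entry. A minimal sketch with SimpleTokenizer follows; the file name and its contents are invented for illustration, and the import assumes the package and its dependencies are installed.

from zipvoice.tokenizer.tokenizer import SimpleTokenizer

with open("tokens.txt", "w", encoding="utf-8") as f:
    for i, tok in enumerate(["_", "a", "b", "c"]):
        f.write(f"{tok}\t{i}\n")

tokenizer = SimpleTokenizer(token_file="tokens.txt")
print(tokenizer.texts_to_token_ids(["abc"]))   # [[1, 2, 3]]; characters missing from the file are skipped as OOV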
zipvoice/utils/checkpoint.py ADDED
@@ -0,0 +1,572 @@
1
+ # Copyright 2021-2025 Xiaomi Corporation (authors: Fangjun Kuang,
2
+ # Zengwei Yao)
3
+ #
4
+ # See ../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import glob
19
+ import logging
20
+ import os
21
+ import re
22
+ from pathlib import Path
23
+ from typing import Any, Dict, List, Optional, Union
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ from lhotse.dataset.sampling.base import CutSampler
28
+ from torch.cuda.amp import GradScaler
29
+ from torch.nn.parallel import DistributedDataParallel as DDP
30
+ from torch.optim import Optimizer
31
+
32
+ from zipvoice.utils.common import AttributeDict
33
+
34
+ # use duck typing for LRScheduler since we have different possibilities, see
35
+ # our class LRScheduler.
36
+ LRSchedulerType = object
37
+
38
+
39
+ def save_checkpoint(
40
+ filename: Path,
41
+ model: Union[nn.Module, DDP],
42
+ model_avg: Optional[nn.Module] = None,
43
+ model_ema: Optional[nn.Module] = None,
44
+ params: Optional[Dict[str, Any]] = None,
45
+ optimizer: Optional[Optimizer] = None,
46
+ scheduler: Optional[LRSchedulerType] = None,
47
+ scaler: Optional[GradScaler] = None,
48
+ sampler: Optional[CutSampler] = None,
49
+ rank: int = 0,
50
+ ) -> None:
51
+ """Save training information to a file.
52
+
53
+ Args:
54
+ filename:
55
+ The checkpoint filename.
56
+ model:
57
+ The model to be saved. We only save its `state_dict()`.
58
+ model_avg:
59
+ The stored model averaged from the start of training.
60
+ model_ema:
61
+ The EMA version of model.
62
+ params:
63
+ User defined parameters, e.g., epoch, loss.
64
+ optimizer:
65
+ The optimizer to be saved. We only save its `state_dict()`.
66
+ scheduler:
67
+ The scheduler to be saved. We only save its `state_dict()`.
68
+ scaler:
69
+ The GradScaler to be saved. We only save its `state_dict()`.
70
+ sampler:
71
+ The sampler used in the labeled training dataset. We only
72
+ save its `state_dict()`.
73
+ rank:
74
+ Used in DDP. We save checkpoint only for the node whose
75
+ rank is 0.
76
+ Returns:
77
+ Return None.
78
+ """
79
+ if rank != 0:
80
+ return
81
+
82
+ logging.info(f"Saving checkpoint to {filename}")
83
+
84
+ if isinstance(model, DDP):
85
+ model = model.module
86
+
87
+ checkpoint = {
88
+ "model": model.state_dict(),
89
+ "optimizer": optimizer.state_dict() if optimizer is not None else None,
90
+ "scheduler": scheduler.state_dict() if scheduler is not None else None,
91
+ "grad_scaler": scaler.state_dict() if scaler is not None else None,
92
+ "sampler": sampler.state_dict() if sampler is not None else None,
93
+ }
94
+
95
+ if model_avg is not None:
96
+ checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
97
+ if model_ema is not None:
98
+ checkpoint["model_ema"] = model_ema.to(torch.float32).state_dict()
99
+
100
+ if params:
101
+ for k, v in params.items():
102
+ assert k not in checkpoint
103
+ checkpoint[k] = v
104
+
105
+ torch.save(checkpoint, filename)
106
+
107
+
108
+ def load_checkpoint(
109
+ filename: Path,
110
+ model: Optional[nn.Module] = None,
111
+ model_avg: Optional[nn.Module] = None,
112
+ model_ema: Optional[nn.Module] = None,
113
+ strict: bool = False,
114
+ ) -> Dict[str, Any]:
115
+ logging.info(f"Loading checkpoint from {filename}")
116
+ checkpoint = torch.load(filename, map_location="cpu", weights_only=False)
117
+
118
+ if model is not None:
119
+
120
+ if next(iter(checkpoint["model"])).startswith("module."):
121
+ logging.info("Loading checkpoint saved by DDP")
122
+
123
+ dst_state_dict = model.state_dict()
124
+ src_state_dict = checkpoint["model"]
125
+ for key in dst_state_dict.keys():
126
+ src_key = "{}.{}".format("module", key)
127
+ dst_state_dict[key] = src_state_dict.pop(src_key)
128
+ assert len(src_state_dict) == 0
129
+ model.load_state_dict(dst_state_dict, strict=strict)
130
+ else:
131
+ logging.info("Loading checkpoint")
132
+ model.load_state_dict(checkpoint["model"], strict=strict)
133
+
134
+ checkpoint.pop("model")
135
+
136
+ if model_avg is not None and "model_avg" in checkpoint:
137
+ logging.info("Loading averaged model")
138
+ model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
139
+ checkpoint.pop("model_avg")
140
+
141
+ if model_ema is not None and "model_ema" in checkpoint:
142
+ logging.info("Loading ema model")
143
+ model_ema.load_state_dict(checkpoint["model_ema"], strict=strict)
144
+ checkpoint.pop("model_ema")
145
+
146
+ return checkpoint
147
+
148
+
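A minimal round-trip sketch (not part of the commit) of how `save_checkpoint` and `load_checkpoint` pair up; the toy model, path, and params below are purely illustrative.
```python
# Illustrative only: a tiny model standing in for ZipVoice, saved and restored.
from pathlib import Path

import torch.nn as nn

from zipvoice.utils.checkpoint import load_checkpoint, save_checkpoint

exp_dir = Path("exp_demo")
exp_dir.mkdir(exist_ok=True)

model = nn.Linear(4, 2)  # hypothetical stand-in for the real model
save_checkpoint(
    filename=exp_dir / "epoch-1.pt",
    model=model,
    params={"epoch": 1, "loss": 0.5},
    rank=0,
)

# Restores the weights in place and returns the remaining checkpoint entries.
info = load_checkpoint(exp_dir / "epoch-1.pt", model=model, strict=True)
print(info["epoch"], info["loss"])  # 1 0.5
```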
149
+ def load_checkpoint_extend_vocab_size(
150
+ filename: Path, extend_size: int, model: nn.Module, strict: bool = True
151
+ ) -> Dict[str, Any]:
152
+ logging.info(f"Loading checkpoint from {filename}")
153
+ checkpoint = torch.load(filename, map_location="cpu", weights_only=False)
154
+
155
+ if model is not None:
156
+ if next(iter(checkpoint["model"])).startswith("module."):
157
+ logging.info("Loading checkpoint saved by DDP")
158
+ dst_state_dict = model.state_dict()
159
+ src_state_dict = checkpoint["model"]
160
+ for key in dst_state_dict.keys():
161
+ src_key = "{}.{}".format("module", key)
162
+ dst_state_dict[key] = src_state_dict.pop(src_key)
163
+ assert len(src_state_dict) == 0
164
+ else:
165
+ logging.info("Loading checkpoint")
166
+ dst_state_dict = checkpoint["model"]
167
+ dst_state_dict["spk_embed.weight"] = model.state_dict()["spk_embed.weight"]
168
+ embed_weight = model.state_dict()["embed.weight"]
169
+ embed_weight[:-extend_size, :] = dst_state_dict["embed.weight"]
170
+ dst_state_dict["embed.weight"] = embed_weight
171
+
172
+ model.load_state_dict(dst_state_dict, strict=strict)
173
+
174
+
175
+ def load_checkpoint_copy_proj_three_channel_alter(
176
+ filename: Path,
177
+ in_proj_key: str,
178
+ out_proj_key: str,
179
+ dim: int,
180
+ model: nn.Module,
181
+ ) -> Dict[str, Any]:
182
+ logging.info(f"Loading checkpoint from {filename}")
183
+ checkpoint = torch.load(filename, map_location="cpu", weights_only=False)
184
+
185
+ if model is not None:
186
+ if next(iter(checkpoint["model"])).startswith("module."):
187
+ logging.info("Loading checkpoint saved by DDP")
188
+
189
+ dst_state_dict = dict()
190
+ src_state_dict = checkpoint["model"]
191
+ for key in list(src_state_dict.keys()):
192
+ dst_state_dict[key[len("module."):]] = src_state_dict.pop(key)
193
+ assert len(src_state_dict) == 0
194
+ else:
195
+ logging.info("Loading checkpoint")
196
+ dst_state_dict = checkpoint["model"]
197
+ keys = list(dst_state_dict.keys())
198
+ for key in keys:
199
+ if in_proj_key in key:
200
+ if "weight" in key:
201
+ weight = dst_state_dict.pop(key)
202
+ dst_state_dict[key.replace("weight", "0.weight")] = torch.cat(
203
+ [
204
+ weight[:, :dim] / 2,
205
+ weight[:, :dim] / 2,
206
+ weight[:, dim : dim * 2],
207
+ weight[:, dim * 2 :] / 2,
208
+ weight[:, dim * 2 :] / 2,
209
+ ],
210
+ dim=-1,
211
+ )
212
+ dst_state_dict[key.replace("weight", "1.weight")] = weight
213
+ if "bias" in key:
214
+ bias = dst_state_dict.pop(key)
215
+ dst_state_dict[key.replace("bias", "0.bias")] = bias
216
+ dst_state_dict[key.replace("bias", "1.bias")] = bias
217
+ if out_proj_key in key:
218
+ if "weight" in key:
219
+ weight = dst_state_dict.pop(key)
220
+ dst_state_dict[key.replace("weight", "0.weight")] = torch.cat(
221
+ [weight, weight], dim=0
222
+ )
223
+ dst_state_dict[key.replace("weight", "1.weight")] = weight
224
+ elif "bias" in key:
225
+ bias = dst_state_dict.pop(key)
226
+ dst_state_dict[key.replace("bias", "0.bias")] = torch.cat(
227
+ [bias, bias], dim=0
228
+ )
229
+ dst_state_dict[key.replace("bias", "1.bias")] = bias
230
+
231
+ model.load_state_dict(dst_state_dict, strict=True)
232
+
233
+
234
+ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
235
+ """Find all available checkpoints in a directory.
236
+
237
+ The checkpoint filenames have the form: `checkpoint-xxx.pt`
238
+ where xxx is a numerical value.
239
+
240
+ Assume you have the following checkpoints in the folder `foo`:
241
+
242
+ - checkpoint-1.pt
243
+ - checkpoint-20.pt
244
+ - checkpoint-300.pt
245
+ - checkpoint-4000.pt
246
+
247
+ Case 1 (Return all checkpoints)::
248
+
249
+ find_checkpoints(out_dir='foo')
250
+
251
+ Case 2 (Return checkpoints at least as new as checkpoint-20.pt, i.e.,
252
+ checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt)::
253
+
254
+ find_checkpoints(out_dir='foo', iteration=20)
255
+
256
+ Case 3 (Return checkpoints older than checkpoint-20.pt, i.e.,
257
+ checkpoint-20.pt, checkpoint-1.pt)::
258
+
259
+ find_checkpoints(out_dir='foo', iteration=-20)
260
+
261
+ Args:
262
+ out_dir:
263
+ The directory where to search for checkpoints.
264
+ iteration:
265
+ If it is 0, return all available checkpoints.
266
+ If it is positive, return the checkpoints whose iteration number is
267
+ greater than or equal to `iteration`.
268
+ If it is negative, return the checkpoints whose iteration number is
269
+ less than or equal to `-iteration`.
270
+ Returns:
271
+ Return a list of checkpoint filenames, sorted in descending
272
+ order by the numerical value in the filename.
273
+ """
274
+ checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
275
+ pattern = re.compile(r"checkpoint-([0-9]+).pt")
276
+ iter_checkpoints = []
277
+ for c in checkpoints:
278
+ result = pattern.search(c)
279
+ if not result:
280
+ logging.warning(f"Invalid checkpoint filename {c}")
281
+ continue
282
+
283
+ iter_checkpoints.append((int(result.group(1)), c))
284
+
285
+ # iter_checkpoints is a list of tuples. Each tuple contains
286
+ # two elements: (iteration_number, checkpoint-iteration_number.pt)
287
+
288
+ iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0])
289
+ if iteration >= 0:
290
+ ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
291
+ else:
292
+ ans = [ic[1] for ic in iter_checkpoints if ic[0] <= -iteration]
293
+
294
+ return ans
295
+
296
+
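For reference, a small sketch (not part of the commit) exercising the three cases described in the docstring above; the checkpoint files are created only for the example.
```python
# Illustrative only: populate a directory and filter checkpoints by iteration.
from pathlib import Path

import torch

from zipvoice.utils.checkpoint import find_checkpoints

out_dir = Path("foo_demo")
out_dir.mkdir(exist_ok=True)
for n in (1, 20, 300, 4000):
    torch.save({}, out_dir / f"checkpoint-{n}.pt")

print(find_checkpoints(out_dir))                 # all four, newest first
print(find_checkpoints(out_dir, iteration=20))   # checkpoint-4000/300/20
print(find_checkpoints(out_dir, iteration=-20))  # checkpoint-20/1
```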
297
+ def average_checkpoints_with_averaged_model(
298
+ filename_start: str,
299
+ filename_end: str,
300
+ device: torch.device = torch.device("cpu"),
301
+ ) -> Dict[str, torch.Tensor]:
302
+ """Average model parameters over the range with given
303
+ start model (excluded) and end model.
304
+
305
+ Let start = batch_idx_train of model-start;
306
+ end = batch_idx_train of model-end;
307
+ interval = end - start.
308
+ Then the average model over range from start (excluded) to end is
309
+ (1) avg = (model_end * end - model_start * start) / interval.
310
+ It can be written as
311
+ (2) avg = model_end * weight_end + model_start * weight_start,
312
+ where weight_end = end / interval,
313
+ weight_start = -start / interval = 1 - weight_end.
314
+ The terms `weight_end` and `weight_start` would be large
315
+ if the model has been trained for many batches, which could cause
316
+ overflow when multiplying the model parameters.
317
+ To avoid this, we rewrite (2) as:
318
+ (3) avg = (model_end + model_start * (weight_start / weight_end))
319
+ * weight_end
320
+
321
+ The model index could be epoch number or iteration number.
322
+
323
+ Args:
324
+ filename_start:
325
+ Checkpoint filename of the start model. We assume it
326
+ is saved by :func:`save_checkpoint`.
327
+ filename_end:
328
+ Checkpoint filename of the end model. We assume it
329
+ is saved by :func:`save_checkpoint`.
330
+ device:
331
+ Move checkpoints to this device before averaging.
332
+ """
333
+ state_dict_start = torch.load(
334
+ filename_start, map_location=device, weights_only=False
335
+ )
336
+ state_dict_end = torch.load(filename_end, map_location=device, weights_only=False)
337
+
338
+ average_period = state_dict_start["average_period"]
339
+
340
+ batch_idx_train_start = state_dict_start["batch_idx_train"]
341
+ batch_idx_train_start = (batch_idx_train_start // average_period) * average_period
342
+ batch_idx_train_end = state_dict_end["batch_idx_train"]
343
+ batch_idx_train_end = (batch_idx_train_end // average_period) * average_period
344
+ interval = batch_idx_train_end - batch_idx_train_start
345
+ assert interval > 0, interval
346
+ weight_end = batch_idx_train_end / interval
347
+ weight_start = 1 - weight_end
348
+
349
+ model_end = state_dict_end["model_avg"]
350
+ model_start = state_dict_start["model_avg"]
351
+ avg = model_end
352
+
353
+ # scale the weight to avoid overflow
354
+ average_state_dict(
355
+ state_dict_1=avg,
356
+ state_dict_2=model_start,
357
+ weight_1=1.0,
358
+ weight_2=weight_start / weight_end,
359
+ scaling_factor=weight_end,
360
+ )
361
+
362
+ return avg
363
+
364
+
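As a quick, self-contained check of the rewriting in the docstring above (toy numbers only, no real checkpoints involved), forms (1) and (3) agree:
```python
# Illustrative only: verify that forms (1) and (3) from the docstring match.
start, end = 1000, 3000              # batch_idx_train of the start / end checkpoints
interval = end - start
weight_end = end / interval          # 1.5
weight_start = 1 - weight_end        # -0.5

model_start, model_end = 2.0, 4.0    # pretend each "model" is a single scalar parameter
avg_direct = (model_end * end - model_start * start) / interval
avg_scaled = (model_end + model_start * (weight_start / weight_end)) * weight_end
assert abs(avg_direct - avg_scaled) < 1e-9  # both equal 5.0
```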
365
+ def remove_checkpoints(
366
+ out_dir: Path,
367
+ topk: int,
368
+ rank: int = 0,
369
+ ):
370
+ """Remove checkpoints from the given directory.
371
+
372
+ We assume that checkpoint filename has the form `checkpoint-xxx.pt`
373
+ where xxx is a number, representing the number of processed batches
374
+ when saving that checkpoint. We sort checkpoints by filename and keep
375
+ only the `topk` checkpoints with the highest `xxx`.
376
+
377
+ Args:
378
+ out_dir:
379
+ The directory containing checkpoints to be removed.
380
+ topk:
381
+ Number of checkpoints to keep.
382
+ rank:
383
+ If using DDP for training, it is the rank of the current node.
384
+ Use 0 if no DDP is used for training.
385
+ """
386
+ assert topk >= 1, topk
387
+ if rank != 0:
388
+ return
389
+ checkpoints = find_checkpoints(out_dir)
390
+
391
+ if len(checkpoints) == 0:
392
+ logging.warning(f"No checkpoints found in {out_dir}")
393
+ return
394
+
395
+ if len(checkpoints) <= topk:
396
+ return
397
+
398
+ to_remove = checkpoints[topk:]
399
+ for c in to_remove:
400
+ os.remove(c)
401
+
402
+
403
+ def resume_checkpoint(
404
+ params: AttributeDict,
405
+ model: nn.Module,
406
+ model_avg: nn.Module,
407
+ model_ema: Optional[nn.Module] = None,
408
+ ) -> Optional[Dict[str, Any]]:
409
+ """Load checkpoint from file.
410
+
411
+ If params.start_epoch is larger than 1, it will load the checkpoint from
412
+ `params.start_epoch - 1`.
413
+
414
+ Apart from loading the state dicts for `model`, `model_avg` and `model_ema`, it also updates
415
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
416
+ and `best_valid_loss` in `params`.
417
+
418
+ Args:
419
+ params:
420
+ The return value of :func:`get_params`.
421
+ model:
422
+ The training model.
423
+ Returns:
424
+ Return a dict containing previously saved training info.
425
+ """
426
+ filename = params.exp_dir / f"epoch-{params.start_epoch - 1}.pt"
427
+
428
+ assert filename.is_file(), f"{filename} does not exist!"
429
+
430
+ saved_params = load_checkpoint(
431
+ filename,
432
+ model=model,
433
+ model_avg=model_avg,
434
+ model_ema=model_ema,
435
+ strict=True,
436
+ )
437
+
438
+ if params.start_epoch > 1:
439
+ keys = [
440
+ "best_train_epoch",
441
+ "best_valid_epoch",
442
+ "batch_idx_train",
443
+ "best_train_loss",
444
+ "best_valid_loss",
445
+ ]
446
+ for k in keys:
447
+ params[k] = saved_params[k]
448
+
449
+ return saved_params
450
+
451
+
452
+ def average_state_dict(
453
+ state_dict_1: Dict[str, torch.Tensor],
454
+ state_dict_2: Dict[str, torch.Tensor],
455
+ weight_1: float,
456
+ weight_2: float,
457
+ scaling_factor: float = 1.0,
458
+ ) -> Dict[str, torch.Tensor]:
459
+ """Average two state_dict with given weights:
460
+ state_dict_1 = (state_dict_1 * weight_1 + state_dict_2 * weight_2)
461
+ * scaling_factor
462
+ It is an in-place operation on state_dict_1 itself.
463
+ """
464
+ # Identify shared parameters. Two parameters are said to be shared
465
+ # if they have the same data_ptr
466
+ uniqued: Dict[int, str] = dict()
467
+ for k, v in state_dict_1.items():
468
+ v_data_ptr = v.data_ptr()
469
+ if v_data_ptr in uniqued:
470
+ continue
471
+ uniqued[v_data_ptr] = k
472
+
473
+ uniqued_names = list(uniqued.values())
474
+ for k in uniqued_names:
475
+ v = state_dict_1[k]
476
+ if torch.is_floating_point(v):
477
+ v *= weight_1
478
+ v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
479
+ v *= scaling_factor
480
+
481
+
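A tiny illustration (not part of the commit) of the in-place weighted sum performed by `average_state_dict`; the tensors are arbitrary.
```python
# Illustrative only: 0.75 * sd1 + 0.25 * sd2, written into sd1 in place.
import torch

from zipvoice.utils.checkpoint import average_state_dict

sd1 = {"w": torch.tensor([1.0, 2.0])}
sd2 = {"w": torch.tensor([3.0, 4.0])}
average_state_dict(sd1, sd2, weight_1=0.75, weight_2=0.25)
print(sd1["w"])  # tensor([1.5000, 2.5000])
```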
482
+ def update_averaged_model(
483
+ params: Dict[str, torch.Tensor],
484
+ model_cur: Union[nn.Module, DDP],
485
+ model_avg: nn.Module,
486
+ ) -> None:
487
+ """Update the averaged model:
488
+ model_avg = model_cur * (average_period / batch_idx_train)
489
+ + model_avg * ((batch_idx_train - average_period) / batch_idx_train)
490
+
491
+ Args:
492
+ params:
493
+ User defined parameters, e.g., epoch, loss.
494
+ model_cur:
495
+ The current model.
496
+ model_avg:
497
+ The averaged model to be updated.
498
+ """
499
+ weight_cur = params.average_period / params.batch_idx_train
500
+ weight_avg = 1 - weight_cur
501
+
502
+ if isinstance(model_cur, DDP):
503
+ model_cur = model_cur.module
504
+
505
+ cur = model_cur.state_dict()
506
+ avg = model_avg.state_dict()
507
+
508
+ average_state_dict(
509
+ state_dict_1=avg,
510
+ state_dict_2=cur,
511
+ weight_1=weight_avg,
512
+ weight_2=weight_cur,
513
+ )
514
+
515
+
516
+ def save_checkpoint_with_global_batch_idx(
517
+ out_dir: Path,
518
+ global_batch_idx: int,
519
+ model: Union[nn.Module, DDP],
520
+ model_avg: Optional[nn.Module] = None,
521
+ params: Optional[Dict[str, Any]] = None,
522
+ optimizer: Optional[Optimizer] = None,
523
+ scheduler: Optional[LRSchedulerType] = None,
524
+ scaler: Optional[GradScaler] = None,
525
+ sampler: Optional[CutSampler] = None,
526
+ rank: int = 0,
527
+ ):
528
+ """Save training info after processing given number of batches.
529
+
530
+ Args:
531
+ out_dir:
532
+ The directory to save the checkpoint.
533
+ global_batch_idx:
534
+ The number of batches processed so far from the very start of the
535
+ training. The saved checkpoint will have the following filename:
536
+
537
+ f'out_dir / checkpoint-{global_batch_idx}.pt'
538
+ model:
539
+ The neural network model whose `state_dict` will be saved in the
540
+ checkpoint.
541
+ model_avg:
542
+ The stored model averaged from the start of training.
543
+ params:
544
+ A dict of training configurations to be saved.
545
+ optimizer:
546
+ The optimizer used in the training. Its `state_dict` will be saved.
547
+ scheduler:
548
+ The learning rate scheduler used in the training. Its `state_dict` will
549
+ be saved.
550
+ scaler:
551
+ The scaler used for mix precision training. Its `state_dict` will
552
+ be saved.
553
+ sampler:
554
+ The sampler used in the training dataset.
555
+ rank:
556
+ The rank ID used in DDP training of the current node. Set it to 0
557
+ if DDP is not used.
558
+ """
559
+ out_dir = Path(out_dir)
560
+ out_dir.mkdir(parents=True, exist_ok=True)
561
+ filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
562
+ save_checkpoint(
563
+ filename=filename,
564
+ model=model,
565
+ model_avg=model_avg,
566
+ params=params,
567
+ optimizer=optimizer,
568
+ scheduler=scheduler,
569
+ scaler=scaler,
570
+ sampler=sampler,
571
+ rank=rank,
572
+ )
zipvoice/utils/common.py ADDED
@@ -0,0 +1,604 @@
1
+ import argparse
2
+ import collections
3
+ import json
4
+ import logging
5
+ import os
6
+ import socket
7
+ import subprocess
8
+ import sys
9
+ from collections import defaultdict
10
+ from datetime import datetime
11
+ from pathlib import Path
12
+ from typing import Any, Dict, List, Tuple, Union
13
+
14
+ import torch
15
+ from torch import distributed as dist
16
+ from torch import nn
17
+ from torch.nn.parallel import DistributedDataParallel as DDP
18
+ from torch.utils.tensorboard import SummaryWriter
19
+
20
+ Pathlike = Union[str, Path]
21
+
22
+
23
+ class AttributeDict(dict):
24
+ def __getattr__(self, key):
25
+ if key in self:
26
+ return self[key]
27
+ raise AttributeError(f"No such attribute '{key}'")
28
+
29
+ def __setattr__(self, key, value):
30
+ self[key] = value
31
+
32
+ def __delattr__(self, key):
33
+ if key in self:
34
+ del self[key]
35
+ return
36
+ raise AttributeError(f"No such attribute '{key}'")
37
+
38
+ def __str__(self, indent: int = 2):
39
+ tmp = {}
40
+ for k, v in self.items():
41
+ # PosixPath is not JSON serializable
42
+ if isinstance(v, (Path, torch.device, torch.dtype)):
43
+ v = str(v)
44
+ tmp[k] = v
45
+ return json.dumps(tmp, indent=indent, sort_keys=True)
46
+
47
+
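A short sketch (not part of the commit) of the attribute-style access this class provides; the keys and values are arbitrary.
```python
# Illustrative only: dict entries are readable and writable as attributes.
from zipvoice.utils.common import AttributeDict

params = AttributeDict({"exp_dir": "exp/zipvoice", "lr": 0.02})
params.batch_idx_train = 0                   # attribute assignment writes into the dict
print(params.lr, params["batch_idx_train"])  # 0.02 0
print(params)                                # pretty-printed JSON via __str__
```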
48
+ class MetricsTracker(collections.defaultdict):
49
+ def __init__(self):
50
+ # Passing the type 'int' to the base-class constructor
51
+ # makes undefined items default to int() which is zero.
52
+ # This class will play a role as metrics tracker.
53
+ # It can record many metrics, including but not limited to loss.
54
+ super(MetricsTracker, self).__init__(int)
55
+
56
+ def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
57
+ ans = MetricsTracker()
58
+ for k, v in self.items():
59
+ ans[k] = v
60
+ for k, v in other.items():
61
+ ans[k] = ans[k] + v
62
+ return ans
63
+
64
+ def __mul__(self, alpha: float) -> "MetricsTracker":
65
+ ans = MetricsTracker()
66
+ for k, v in self.items():
67
+ ans[k] = v * alpha
68
+ return ans
69
+
70
+ def __str__(self) -> str:
71
+ ans_frames = ""
72
+ ans_utterances = ""
73
+ for k, v in self.norm_items():
74
+ norm_value = "%.4g" % v
75
+ if "utt_" not in k:
76
+ ans_frames += str(k) + "=" + str(norm_value) + ", "
77
+ else:
78
+ ans_utterances += str(k) + "=" + str(norm_value)
79
+ if k == "utt_duration":
80
+ ans_utterances += " frames, "
81
+ elif k == "utt_pad_proportion":
82
+ ans_utterances += ", "
83
+ else:
84
+ raise ValueError(f"Unexpected key: {k}")
85
+ frames = "%.2f" % self["frames"]
86
+ ans_frames += "over " + str(frames) + " frames. "
87
+ if ans_utterances != "":
88
+ utterances = "%.2f" % self["utterances"]
89
+ ans_utterances += "over " + str(utterances) + " utterances."
90
+
91
+ return ans_frames + ans_utterances
92
+
93
+ def norm_items(self) -> List[Tuple[str, float]]:
94
+ """
95
+ Returns a list of pairs, like:
96
+ [('ctc_loss', 0.1), ('att_loss', 0.07)]
97
+ """
98
+ num_frames = self["frames"] if "frames" in self else 1
99
+ num_utterances = self["utterances"] if "utterances" in self else 1
100
+ ans = []
101
+ for k, v in self.items():
102
+ if k == "frames" or k == "utterances":
103
+ continue
104
+ norm_value = (
105
+ float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
106
+ )
107
+ ans.append((k, norm_value))
108
+ return ans
109
+
110
+ def reduce(self, device):
111
+ """
112
+ Reduce using torch.distributed, which I believe ensures that
113
+ all processes get the total.
114
+ """
115
+ keys = sorted(self.keys())
116
+ s = torch.tensor([float(self[k]) for k in keys], device=device)
117
+ dist.all_reduce(s, op=dist.ReduceOp.SUM)
118
+ for k, v in zip(keys, s.cpu().tolist()):
119
+ self[k] = v
120
+
121
+ def write_summary(
122
+ self,
123
+ tb_writer: SummaryWriter,
124
+ prefix: str,
125
+ batch_idx: int,
126
+ ) -> None:
127
+ """Add logging information to a TensorBoard writer.
128
+
129
+ Args:
130
+ tb_writer: a TensorBoard writer
131
+ prefix: a prefix for the name of the loss, e.g. "train/valid_",
132
+ or "train/current_"
133
+ batch_idx: The current batch index, used as the x-axis of the plot.
134
+ """
135
+ for k, v in self.norm_items():
136
+ tb_writer.add_scalar(prefix + k, v, batch_idx)
137
+
138
+
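A minimal sketch (not part of the commit) of how two `MetricsTracker` instances are combined and normalized per frame; the numbers are made up.
```python
# Illustrative only: sum two trackers and report the per-frame loss.
from zipvoice.utils.common import MetricsTracker

a = MetricsTracker()
a["frames"] = 200
a["loss"] = 10.0

b = MetricsTracker()
b["frames"] = 300
b["loss"] = 12.0

total = a + b
print(total.norm_items())  # [('loss', 0.044)], i.e. 22.0 / 500 frames
print(str(total))          # "loss=0.044, over 500.00 frames. "
```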
139
+ def setup_dist(
140
+ rank=None,
141
+ world_size=None,
142
+ master_port=None,
143
+ use_ddp_launch=False,
144
+ master_addr=None,
145
+ ):
146
+ """
147
+ rank and world_size are used only if use_ddp_launch is False.
148
+ """
149
+ if "MASTER_ADDR" not in os.environ:
150
+ os.environ["MASTER_ADDR"] = (
151
+ "localhost" if master_addr is None else str(master_addr)
152
+ )
153
+
154
+ if "MASTER_PORT" not in os.environ:
155
+ os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port)
156
+
157
+ if use_ddp_launch is False:
158
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
159
+ torch.cuda.set_device(rank)
160
+ else:
161
+ dist.init_process_group("nccl")
162
+
163
+
164
+ def cleanup_dist():
165
+ dist.destroy_process_group()
166
+
167
+
168
+ def prepare_input(
169
+ params: AttributeDict,
170
+ batch: dict,
171
+ device: torch.device,
172
+ return_tokens: bool = True,
173
+ return_feature: bool = True,
174
+ return_audio: bool = False,
175
+ ):
176
+ """
177
+ Parse the features and targets of the current batch.
178
+ Args:
179
+ params:
180
+ It is returned by :func:`get_params`.
181
+ batch:
182
+ It is the return value from iterating
183
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
184
+ for the format of the `batch`.
185
+ device:
186
+ The device of Tensor.
187
+ """
188
+ return_list = []
189
+
190
+ if return_tokens:
191
+ return_list += [batch["tokens"]]
192
+
193
+ if return_feature:
194
+ features = batch["features"].to(device)
195
+ features_lens = batch["features_lens"].to(device)
196
+ return_list += [features * params.feat_scale, features_lens]
197
+
198
+ if return_audio:
199
+ return_list += [batch["audio"], batch["audio_lens"]]
200
+
201
+ return return_list
202
+
203
+
204
+ def prepare_avg_tokens_durations(features_lens, tokens_lens):
205
+ tokens_durations = []
206
+ for i in range(len(features_lens)):
207
+ utt_duration = features_lens[i]
208
+ avg_token_duration = utt_duration // tokens_lens[i]
209
+ tokens_durations.append([avg_token_duration] * tokens_lens[i])
210
+ return tokens_durations
211
+
212
+
213
+ def pad_labels(y: List[List[int]], pad_id: int, device: torch.device):
214
+ """
215
+ Pad the transcripts to the same length with `pad_id` (one `pad_id` is also appended to each).
216
+
217
+ Args:
218
+ y: the transcripts, which is a list of a list
219
+
220
+ Returns:
221
+ Return a Tensor of padded transcripts.
222
+ """
223
+ y = [token_ids + [pad_id] for token_ids in y]
224
+ length = max([len(token_ids) for token_ids in y])
225
+ y = [token_ids + [pad_id] * (length - len(token_ids)) for token_ids in y]
226
+ return torch.tensor(y, dtype=torch.int64, device=device)
227
+
228
+
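For clarity, a small example (not part of the commit) of the padding behaviour; the token ids are arbitrary.
```python
# Illustrative only: a pad_id is appended to each sequence, then all are padded to equal length.
import torch

from zipvoice.utils.common import pad_labels

y = [[5, 7], [2, 9, 4]]
print(pad_labels(y, pad_id=0, device=torch.device("cpu")))
# tensor([[5, 7, 0, 0],
#         [2, 9, 4, 0]])
```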
229
+ def get_tokens_index(durations: List[List[int]], num_frames: int) -> torch.Tensor:
230
+ """
231
+ Gets position in the transcript for each frame, i.e. the position
232
+ in the symbol-sequence to look up.
233
+
234
+ Args:
235
+ durations:
236
+ Duration of each token in transcripts.
237
+ num_frames:
238
+ The maximum frame length of the current batch.
239
+
240
+ Returns:
241
+ Return a Tensor of shape (batch_size, num_frames)
242
+ """
243
+ durations = [x + [num_frames - sum(x)] for x in durations]
244
+ batch_size = len(durations)
245
+ ans = torch.zeros(batch_size, num_frames, dtype=torch.int64)
246
+ for b in range(batch_size):
247
+ this_dur = durations[b]
248
+ cur_frame = 0
249
+ for i, d in enumerate(this_dur):
250
+ ans[b, cur_frame : cur_frame + d] = i
251
+ cur_frame += d
252
+ assert cur_frame == num_frames, (cur_frame, num_frames)
253
+ return ans
254
+
255
+
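A short worked example (not part of the commit) of the frame-to-token index produced above; the durations are chosen arbitrarily.
```python
# Illustrative only: each frame is mapped to the index of the token it belongs to;
# the final index covers the remainder frames appended by the function.
from zipvoice.utils.common import get_tokens_index

durations = [[2, 3], [4, 1]]  # per-utterance token durations in frames
print(get_tokens_index(durations, num_frames=6))
# tensor([[0, 0, 1, 1, 1, 2],
#         [0, 0, 0, 0, 1, 2]])
```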
256
+ def to_int_tuple(s: Union[str, int]):
257
+ if isinstance(s, int):
258
+ return (s,)
259
+ return tuple(map(int, s.split(",")))
260
+
261
+
262
+ def get_adjusted_batch_count(params: AttributeDict) -> float:
263
+ # returns the number of batches we would have used so far if we had used the
264
+ # reference duration. This is for purposes of set_batch_count().
265
+ return (
266
+ params.batch_idx_train
267
+ * (params.max_duration * params.world_size)
268
+ / params.ref_duration
269
+ )
270
+
271
+
272
+ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
273
+ if isinstance(model, DDP):
274
+ # get underlying nn.Module
275
+ model = model.module
276
+ for name, module in model.named_modules():
277
+ if hasattr(module, "batch_count"):
278
+ module.batch_count = batch_count
279
+ if hasattr(module, "name"):
280
+ module.name = name
281
+
282
+
283
+ def condition_time_mask(
284
+ features_lens: torch.Tensor,
285
+ mask_percent: Tuple[float, float],
286
+ max_len: int = 0,
287
+ ) -> torch.Tensor:
288
+ """
289
+ Apply Time masking.
290
+ Args:
291
+ features_lens:
292
+ input tensor of shape ``(B)``
293
+ mask_percent:
294
+ the (min, max) fraction of the utterance length to mask.
295
+ max_len:
296
+ the returned mask has at least this many frames (columns).
297
+ Returns:
298
+ Return a 2-D bool tensor (B, T), where masked positions
299
+ are filled with `True` and non-masked positions are
300
+ filled with `False`.
301
+ """
302
+ mask_size = (
303
+ torch.zeros_like(features_lens, dtype=torch.float32).uniform_(*mask_percent)
304
+ * features_lens
305
+ ).to(torch.int64)
306
+ mask_starts = (
307
+ torch.rand_like(mask_size, dtype=torch.float32) * (features_lens - mask_size)
308
+ ).to(torch.int64)
309
+ mask_ends = mask_starts + mask_size
310
+ max_len = max(max_len, features_lens.max())
311
+ seq_range = torch.arange(0, max_len, device=features_lens.device)
312
+ mask = (seq_range[None, :] >= mask_starts[:, None]) & (
313
+ seq_range[None, :] < mask_ends[:, None]
314
+ )
315
+ return mask
316
+
317
+
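A brief sketch (not part of the commit) of the mask shape and coverage; the lengths and percentages are illustrative.
```python
# Illustrative only: each row masks one random contiguous span of its utterance.
import torch

from zipvoice.utils.common import condition_time_mask

lens = torch.tensor([80, 100])
mask = condition_time_mask(lens, mask_percent=(0.3, 0.7))
print(mask.shape)            # torch.Size([2, 100])
print(mask[0].sum().item())  # roughly 30%-70% of 80 frames are True
```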
318
+ def condition_time_mask_suffix(
319
+ features_lens: torch.Tensor,
320
+ mask_percent: Tuple[float, float],
321
+ max_len: int = 0,
322
+ ) -> torch.Tensor:
323
+ """
324
+ Apply Time masking, mask from the end time index.
325
+ Args:
326
+ features_lens:
327
+ input tensor of shape ``(B)``
328
+ mask_percent:
329
+ the (min, max) fraction of the utterance length to mask.
330
+ max_len:
331
+ the returned mask has at least this many frames (columns).
332
+ Returns:
333
+ Return a 2-D bool tensor (B, T), where masked positions
334
+ are filled with `True` and non-masked positions are
335
+ filled with `False`.
336
+ """
337
+ mask_size = (
338
+ torch.zeros_like(features_lens, dtype=torch.float32).uniform_(*mask_percent)
339
+ * features_lens
340
+ ).to(torch.int64)
341
+ mask_starts = (
342
+ torch.ones_like(mask_size, dtype=torch.float32) * (features_lens - mask_size)
343
+ ).to(torch.int64)
344
+ mask_ends = mask_starts + mask_size
345
+ max_len = max(max_len, features_lens.max())
346
+ seq_range = torch.arange(0, max_len, device=features_lens.device)
347
+ mask = (seq_range[None, :] >= mask_starts[:, None]) & (
348
+ seq_range[None, :] < mask_ends[:, None]
349
+ )
350
+ return mask
351
+
352
+
353
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
354
+ """
355
+ Args:
356
+ lengths:
357
+ A 1-D tensor containing sentence lengths.
358
+ max_len:
359
+ The length of masks.
360
+ Returns:
361
+ Return a 2-D bool tensor, where masked positions
362
+ are filled with `True` and non-masked positions are
363
+ filled with `False`.
364
+
365
+ >>> lengths = torch.tensor([1, 3, 2, 5])
366
+ >>> make_pad_mask(lengths)
367
+ tensor([[False, True, True, True, True],
368
+ [False, False, False, True, True],
369
+ [False, False, True, True, True],
370
+ [False, False, False, False, False]])
371
+ """
372
+ assert lengths.ndim == 1, lengths.ndim
373
+ max_len = max(max_len, lengths.max())
374
+ n = lengths.size(0)
375
+ seq_range = torch.arange(0, max_len, device=lengths.device)
376
+ expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
377
+
378
+ return expanded_lengths >= lengths.unsqueeze(-1)
379
+
380
+
381
+ def str2bool(v):
382
+ """Used in argparse.ArgumentParser.add_argument to indicate
383
+ that a type is a bool type and user can enter
384
+
385
+ - yes, true, t, y, 1, to represent True
386
+ - no, false, f, n, 0, to represent False
387
+
388
+ See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
389
+ """
390
+ if isinstance(v, bool):
391
+ return v
392
+ if v.lower() in ("yes", "true", "t", "y", "1"):
393
+ return True
394
+ elif v.lower() in ("no", "false", "f", "n", "0"):
395
+ return False
396
+ else:
397
+ raise argparse.ArgumentTypeError("Boolean value expected.")
398
+
399
+
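Typical argparse wiring for the helper above (not part of the commit; the flag name is just an example).
```python
# Illustrative only: accept yes/no-style strings for a boolean flag.
import argparse

from zipvoice.utils.common import str2bool

parser = argparse.ArgumentParser()
parser.add_argument("--use-fp16", type=str2bool, default=True)
args = parser.parse_args(["--use-fp16", "false"])
print(args.use_fp16)  # False
```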
400
+ def setup_logger(
401
+ log_filename: Pathlike,
402
+ log_level: str = "info",
403
+ use_console: bool = True,
404
+ ) -> None:
405
+ """Setup log level.
406
+
407
+ Args:
408
+ log_filename:
409
+ The filename to save the log.
410
+ log_level:
411
+ The log level to use, e.g., "debug", "info", "warning", "error",
412
+ "critical"
413
+ use_console:
414
+ True to also print logs to console.
415
+ """
416
+ now = datetime.now()
417
+ date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
418
+ if dist.is_available() and dist.is_initialized():
419
+ world_size = dist.get_world_size()
420
+ rank = dist.get_rank()
421
+ formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
422
+ log_filename = f"{log_filename}-{date_time}-{rank}"
423
+ else:
424
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
425
+ log_filename = f"{log_filename}-{date_time}"
426
+
427
+ os.makedirs(os.path.dirname(log_filename), exist_ok=True)
428
+
429
+ level = logging.ERROR
430
+ if log_level == "debug":
431
+ level = logging.DEBUG
432
+ elif log_level == "info":
433
+ level = logging.INFO
434
+ elif log_level == "warning":
435
+ level = logging.WARNING
436
+ elif log_level == "critical":
437
+ level = logging.CRITICAL
438
+
439
+ logging.basicConfig(
440
+ filename=log_filename,
441
+ format=formatter,
442
+ level=level,
443
+ filemode="w",
444
+ force=True,
445
+ )
446
+ if use_console:
447
+ console = logging.StreamHandler()
448
+ console.setLevel(level)
449
+ console.setFormatter(logging.Formatter(formatter))
450
+ logging.getLogger("").addHandler(console)
451
+
452
+
453
+ def get_git_sha1():
454
+ try:
455
+ git_commit = (
456
+ subprocess.run(
457
+ ["git", "rev-parse", "--short", "HEAD"],
458
+ check=True,
459
+ stdout=subprocess.PIPE,
460
+ )
461
+ .stdout.decode()
462
+ .rstrip("\n")
463
+ .strip()
464
+ )
465
+ dirty_commit = (
466
+ len(
467
+ subprocess.run(
468
+ ["git", "diff", "--shortstat"],
469
+ check=True,
470
+ stdout=subprocess.PIPE,
471
+ )
472
+ .stdout.decode()
473
+ .rstrip("\n")
474
+ .strip()
475
+ )
476
+ > 0
477
+ )
478
+ git_commit = git_commit + "-dirty" if dirty_commit else git_commit + "-clean"
479
+ except: # noqa
480
+ return None
481
+
482
+ return git_commit
483
+
484
+
485
+ def get_git_date():
486
+ try:
487
+ git_date = (
488
+ subprocess.run(
489
+ ["git", "log", "-1", "--format=%ad", "--date=local"],
490
+ check=True,
491
+ stdout=subprocess.PIPE,
492
+ )
493
+ .stdout.decode()
494
+ .rstrip("\n")
495
+ .strip()
496
+ )
497
+ except: # noqa
498
+ return None
499
+
500
+ return git_date
501
+
502
+
503
+ def get_git_branch_name():
504
+ try:
505
+ git_branch = (
506
+ subprocess.run(
507
+ ["git", "rev-parse", "--abbrev-ref", "HEAD"],
508
+ check=True,
509
+ stdout=subprocess.PIPE,
510
+ )
511
+ .stdout.decode()
512
+ .rstrip("\n")
513
+ .strip()
514
+ )
515
+ except: # noqa
516
+ return None
517
+
518
+ return git_branch
519
+
520
+
521
+ def get_env_info() -> Dict[str, Any]:
522
+ """Get the environment information."""
523
+ return {
524
+ "torch-version": str(torch.__version__),
525
+ "torch-cuda-available": torch.cuda.is_available(),
526
+ "torch-cuda-version": torch.version.cuda,
527
+ "python-version": sys.version[:4],
528
+ "zipvoice-git-branch": get_git_branch_name(),
529
+ "zipvoice-git-sha1": get_git_sha1(),
530
+ "zipvoice-git-date": get_git_date(),
531
+ "zipvoice-path": str(Path(__file__).resolve().parent.parent),
532
+ "hostname": socket.gethostname(),
533
+ "IP address": socket.gethostbyname(socket.gethostname()),
534
+ }
535
+
536
+
537
+ def get_parameter_groups_with_lrs(
538
+ model: nn.Module,
539
+ lr: float,
540
+ include_names: bool = False,
541
+ freeze_modules: List[str] = [],
542
+ ) -> List[dict]:
543
+ """
544
+ This is for use with the ScaledAdam optimizers (more recent versions that accept
545
+ lists of named-parameters; we can, if needed, create a version without the names).
546
+
547
+ It provides a way to specify learning-rate scales inside the module, so that if
548
+ any nn.Module in the hierarchy has a floating-point parameter 'lr_scale', it will
549
+ scale the LR of any parameters inside that module or its submodules. Note: you
550
+ can set module parameters outside the __init__ function, e.g.:
551
+ >>> a = nn.Linear(10, 10)
552
+ >>> a.lr_scale = 0.5
553
+
554
+ Returns: a list of dicts, of the following form:
555
+ if include_names == False:
556
+ [ { 'params': [ tensor1, tensor2, ... ], 'lr': 0.01 },
557
+ { 'params': [ tensor3, tensor4, ... ], 'lr': 0.005 },
558
+ ... ]
559
+ if include_names == True:
560
+ [ { 'named_params': [ (name1, tensor1), (name2, tensor2), ... ], 'lr': 0.01 },
561
+ { 'named_params': [ (name3, tensor3), (name4, tensor4), ... ], 'lr': 0.005 },
562
+ ... ]
563
+
564
+ """
565
+ # flat_lr_scale just contains the lr_scale explicitly specified
566
+ # for each prefix of the name, e.g. 'encoder.layers.3', these need
567
+ # to be multiplied over all prefixes of the name of any given parameter.
568
+ flat_lr_scale = defaultdict(lambda: 1.0)
569
+ names = []
570
+ for name, m in model.named_modules():
571
+ names.append(name)
572
+ if hasattr(m, "lr_scale"):
573
+ flat_lr_scale[name] = m.lr_scale
574
+
575
+ # lr_to_params is a dict from learning rate (floating point) to: if
576
+ # include_names == true, a list of (name, parameter) for that learning rate;
577
+ # otherwise a list of parameters for that learning rate.
578
+ lr_to_params = defaultdict(list)
579
+
580
+ for name, parameter in model.named_parameters():
581
+ split_name = name.split(".")
582
+ # caution: as a special case, if the name is '', split_name will be [ '' ].
583
+ prefix = split_name[0]
584
+ if prefix == "module": # DDP
585
+ module_name = split_name[1]
586
+ if module_name in freeze_modules:
587
+ logging.info(f"Remove {name} from parameters")
588
+ continue
589
+ else:
590
+ if prefix in freeze_modules:
591
+ logging.info(f"Remove {name} from parameters")
592
+ continue
593
+ cur_lr = lr * flat_lr_scale[prefix]
594
+ if prefix != "":
595
+ cur_lr *= flat_lr_scale[""]
596
+ for part in split_name[1:]:
597
+ prefix = ".".join([prefix, part])
598
+ cur_lr *= flat_lr_scale[prefix]
599
+ lr_to_params[cur_lr].append((name, parameter) if include_names else parameter)
600
+
601
+ if include_names:
602
+ return [{"named_params": pairs, "lr": lr} for lr, pairs in lr_to_params.items()]
603
+ else:
604
+ return [{"params": params, "lr": lr} for lr, params in lr_to_params.items()]
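A minimal sketch (not part of the commit) of the `lr_scale` mechanism described in the docstring; the toy model and base learning rate are made up.
```python
# Illustrative only: the second layer gets half of the base learning rate.
import torch.nn as nn

from zipvoice.utils.common import get_parameter_groups_with_lrs

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
model[1].lr_scale = 0.5
groups = get_parameter_groups_with_lrs(model, lr=0.02)
print(sorted(g["lr"] for g in groups))  # [0.01, 0.02]
```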
zipvoice/utils/diagnostics.py ADDED
@@ -0,0 +1,723 @@
1
+ # Copyright 2022-2024 Xiaomi Corp. (authors: Daniel Povey
2
+ # Zengwei Yao
3
+ # Mingshuang Luo,
4
+ # Zengrui Jin,)
5
+ #
6
+ # See ../LICENSE for clarification regarding multiple authors
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ import logging
21
+ import random
22
+ from dataclasses import dataclass
23
+ from typing import Optional, Tuple
24
+
25
+ import torch
26
+ from torch import Tensor, nn
27
+
28
+
29
+ class TensorDiagnosticOptions(object):
30
+ """Options object for tensor diagnostics:
31
+
32
+ Args:
33
+ max_eig_dim:
34
+ The maximum dimension for which we print out eigenvalues
35
+ (limited for speed reasons).
36
+ """
37
+
38
+ def __init__(self, max_eig_dim: int = 512):
39
+ self.max_eig_dim = max_eig_dim
40
+
41
+ def dim_is_summarized(self, size: int):
42
+ return size > 10 and size != 31
43
+
44
+
45
+ def get_tensor_stats(
46
+ x: Tensor,
47
+ dim: int,
48
+ stats_type: str,
49
+ ) -> Tuple[Tensor, int]:
50
+ """
51
+ Returns the specified transformation of the Tensor (either x, x.abs(),
52
+ or (x > 0)), summed over all but the index `dim`.
53
+
54
+ Args:
55
+ x:
56
+ Tensor, tensor to be analyzed
57
+ dim:
58
+ Dimension with 0 <= dim < x.ndim
59
+ stats_type:
60
+ The stats_type includes several types:
61
+ "abs" -> take abs() before summing
62
+ "positive" -> take (x > 0) before summing
63
+ "rms" -> square before summing, we'll take sqrt later
64
+ "value" -> just sum x itself
65
+ "max", "min" -> take the maximum or minimum [over all other dims but dim]
66
+ instead of summing
67
+ "rms-sort" -> this is a bit different than the others, it's based on computing
68
+ the rms over the specified dim and returning percentiles of the result
69
+ (11 of them).
70
+ Returns:
71
+ stats: a Tensor of shape (x.shape[dim],).
72
+ count: an integer saying how many items were counted in each element
73
+ of stats.
74
+ """
75
+
76
+ if stats_type == "rms-sort":
77
+ rms = (x**2).mean(dim=dim).sqrt()
78
+ rms = rms.flatten()
79
+ rms = rms.sort()[0]
80
+ rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)]
81
+ count = 1.0
82
+ return rms, count
83
+
84
+ count = x.numel() // x.shape[dim]
85
+
86
+ if stats_type == "eigs":
87
+ x = x.transpose(dim, -1)
88
+ x = x.reshape(-1, x.shape[-1])
89
+ # shape of returned tensor: (s, s),
90
+ # where s is size of dimension `dim` of original x.
91
+ return torch.matmul(x.transpose(0, 1), x), count
92
+ elif stats_type == "abs":
93
+ x = x.abs()
94
+ elif stats_type == "rms":
95
+ x = x**2
96
+ elif stats_type == "positive":
97
+ x = (x > 0).to(dtype=torch.float)
98
+ else:
99
+ assert stats_type in ["value", "max", "min"]
100
+
101
+ sum_dims = [d for d in range(x.ndim) if d != dim]
102
+ if len(sum_dims) > 0:
103
+ if stats_type == "max":
104
+ for dim in reversed(sum_dims):
105
+ x = torch.max(x, dim=dim)[0]
106
+ elif stats_type == "min":
107
+ for dim in reversed(sum_dims):
108
+ x = torch.min(x, dim=dim)[0]
109
+ else:
110
+ x = torch.sum(x, dim=sum_dims)
111
+ x = x.flatten().clone()
112
+ return x, count
113
+
114
+
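A quick sketch (not part of the commit) of one stats type; the tensor shape and stats_type are chosen only to show the output shape and count.
```python
# Illustrative only: per-channel count of positive entries along dim=1.
import torch

from zipvoice.utils.diagnostics import get_tensor_stats

x = torch.randn(3, 5)
stats, count = get_tensor_stats(x, dim=1, stats_type="positive")
print(stats.shape, count)  # torch.Size([5]) 3
```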
115
+ @dataclass
116
+ class TensorAndCount:
117
+ tensor: Tensor
118
+ count: int
119
+
120
+
121
+ class TensorDiagnostic(object):
122
+ """This class is not directly used by the user, it is responsible for
123
+ collecting diagnostics for a module or parameter tensor of a torch.nn.Module.
124
+
125
+ Args:
126
+ opts:
127
+ Options object.
128
+ name:
129
+ The name associated with this diagnostics object, will probably be
130
+ {module_name}.X where X is "output" or "grad", or {parameter_name}.
131
+ Y where Y is param_value or param_grad.
132
+ """
133
+
134
+ def __init__(self, opts: TensorDiagnosticOptions, name: str):
135
+ self.opts = opts
136
+ self.name = name
137
+ self.class_name = None # will assign in accumulate()
138
+
139
+ self.stats = None # we'll later assign a list to self.stats.
140
+ # It's a list of dicts, indexed by dim (i.e. by the
141
+ # axis of the tensor). The dicts, in turn, are
142
+ # indexed by `stats-type` which are strings in
143
+ # ["abs", "max", "min", "positive", "value", "rms"].
144
+
145
+ # scalar_stats contains some analysis of the activations and gradients,
146
+ self.scalar_stats = None
147
+
148
+ # the keys into self.stats[dim] are strings, whose values can be
149
+ # "abs", "max", "min" ,"value", "positive", "rms", "value".
150
+ # The values e.g. self.stats[dim]["rms"] are lists of dataclass TensorAndCount,
151
+ # containing a tensor and its associated count (which is the sum of the other
152
+ # dims that we aggregated over, e.g. the number of frames and/or batch elements
153
+ # and/or channels.
154
+ # ... we actually accumulate the Tensors / counts any time we have the same-dim
155
+ # tensor, only adding a new element to the list if there was a different dim.
156
+ # if the string in the key is "eigs", if we detect a length mismatch we put None
157
+ # as the value.
158
+
159
+ def accumulate(self, x, class_name: Optional[str] = None):
160
+ """
161
+ Accumulate tensors.
162
+ """
163
+ if class_name is not None:
164
+ self.class_name = class_name
165
+ if isinstance(x, Tuple):
166
+ x = x[0]
167
+ if not isinstance(x, Tensor):
168
+ return
169
+ if x.numel() == 0: # for empty tensor
170
+ return
171
+ x = x.detach().clone()
172
+ if x.ndim == 0:
173
+ x = x.unsqueeze(0)
174
+ ndim = x.ndim
175
+ if self.stats is None:
176
+ self.stats = [dict() for _ in range(ndim)]
177
+
178
+ for dim in range(ndim):
179
+ this_dim_stats = self.stats[dim]
180
+ if ndim > 1:
181
+ # rms-sort is different from the others, it's based on summing over just
182
+ # this dim, then sorting and returning the percentiles.
183
+ stats_types = [
184
+ "abs",
185
+ "max",
186
+ "min",
187
+ "positive",
188
+ "value",
189
+ "rms",
190
+ "rms-sort",
191
+ ]
192
+ if x.shape[dim] <= self.opts.max_eig_dim:
193
+ stats_types.append("eigs")
194
+ else:
195
+ stats_types = ["value", "abs", "max", "min"]
196
+
197
+ for stats_type in stats_types:
198
+ stats, count = get_tensor_stats(x, dim, stats_type)
199
+ if stats_type not in this_dim_stats:
200
+ this_dim_stats[stats_type] = [] # list of TensorAndCount
201
+
202
+ done = False
203
+ if this_dim_stats[stats_type] is None:
204
+ # we can reach here if we detected for stats_type "eigs" that
205
+ # there was more than one different size for this dim. Then we
206
+ # disable accumulating this stats type, as it uses too much memory.
207
+ continue
208
+ for s in this_dim_stats[stats_type]:
209
+ if s.tensor.shape == stats.shape:
210
+ if stats_type == "max":
211
+ s.tensor = torch.maximum(s.tensor, stats)
212
+
213
+ elif stats_type == "min":
214
+ s.tensor = torch.minimum(s.tensor, stats)
215
+ else:
216
+ assert stats_type != "max"
217
+ s.tensor += stats
218
+ s.count += count
219
+ done = True
220
+ break
221
+ if not done:
222
+ if this_dim_stats[stats_type] != [] and stats_type == "eigs":
223
+ # >1 size encountered on this dim, e.g. it's a batch or time
224
+ # dimension, don't accumulate the "eigs" stats type, it uses too much
225
+ # memory
226
+ this_dim_stats[stats_type] = None
227
+ else:
228
+ this_dim_stats[stats_type].append(TensorAndCount(stats, count))
229
+
230
+ def print_diagnostics(self):
231
+ """Print diagnostics for each dimension of the tensor."""
232
+ if self.stats is None:
233
+ print(f"Warning: the stats of {self.name} is None.")
234
+ return
235
+ for dim, this_dim_stats in enumerate(self.stats):
236
+ if "rms" in this_dim_stats and "value" in this_dim_stats:
237
+ # produce "stddev" stats, which is centered RMS.
238
+ rms_stats_list = this_dim_stats["rms"]
239
+ value_stats_list = this_dim_stats["value"]
240
+ if len(rms_stats_list) == len(value_stats_list):
241
+ stddev_stats_list = []
242
+ for r, v in zip(rms_stats_list, value_stats_list):
243
+ stddev_stats_list.append(
244
+ # r.count and v.count should be the same, but we don't check
245
+ # this.
246
+ TensorAndCount(
247
+ r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20),
248
+ r.count,
249
+ )
250
+ )
251
+ this_dim_stats["stddev"] = stddev_stats_list
252
+
253
+ for stats_type, stats_list in this_dim_stats.items():
254
+ # stats_type could be "rms", "value", "abs", "eigs", "positive", "min"
255
+ # or "max". "stats_list" could be a list of TensorAndCount (one list per
256
+ # distinct tensor shape of the stats), or None
257
+ if stats_list is None:
258
+ assert stats_type == "eigs"
259
+ continue
260
+
261
+ def get_count(count):
262
+ return 1 if stats_type in ["max", "min"] else count
263
+
264
+ if len(stats_list) == 1:
265
+ stats = stats_list[0].tensor / get_count(stats_list[0].count)
266
+ else:
267
+ # a dimension that has variable size in different nnet
268
+ # forwards, e.g. a time dimension in an ASR model.
269
+ stats = torch.cat(
270
+ [x.tensor / get_count(x.count) for x in stats_list], dim=0
271
+ )
272
+
273
+ if stats_type == "eigs":
274
+ try:
275
+ if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
276
+ eigs, _ = torch.linalg.eigh(stats)
277
+ else:
278
+ eigs, _ = torch.symeig(stats)
279
+ stats = eigs.abs().sqrt()
280
+ except: # noqa
281
+ print("Error getting eigenvalues, trying another method.")
282
+ if hasattr(torch, "linalg") and hasattr(torch.linalg, "eig"):
283
+ eigs, _ = torch.linalg.eig(stats)
284
+ eigs = eigs.abs()
285
+ else:
286
+ eigs, _ = torch.eig(stats)
287
+ eigs = eigs.norm(dim=1)
288
+ stats = eigs.sqrt()
289
+ # sqrt so it reflects data magnitude, like stddev, not variance
290
+
291
+ if stats_type in ["rms", "stddev"]:
292
+ # we stored the square; after aggregation we need to take sqrt.
293
+ stats = stats.sqrt()
294
+
295
+ # if `summarize` we print percentiles of the stats; else,
296
+ # we print out individual elements.
297
+ summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
298
+ stats.numel()
299
+ )
300
+ if summarize: # usually `summarize` will be true
301
+ # print out percentiles.
302
+ stats = stats.sort()[0]
303
+ num_percentiles = 10
304
+ size = stats.numel()
305
+ percentiles = []
306
+ for i in range(num_percentiles + 1):
307
+ index = (i * (size - 1)) // num_percentiles
308
+ percentiles.append(stats[index].item())
309
+ percentiles = ["%.2g" % x for x in percentiles]
310
+ percentiles = " ".join(percentiles)
311
+ ans = f"percentiles: [{percentiles}]"
312
+ else:
313
+ ans = stats.tolist()
314
+ ans = ["%.2g" % x for x in ans]
315
+ ans = "[" + " ".join(ans) + "]"
316
+ if stats_type in ["value", "rms", "stddev", "eigs"]:
317
+ # This norm is useful because it is strictly less than the largest
318
+ # sqrt(eigenvalue) of the variance, which we print out, and shows,
319
+ # speaking in an approximate way, how much of that largest
320
+ # eigenvalue can be attributed to the mean of the distribution.
321
+ norm = (stats**2).sum().sqrt().item()
322
+ ans += f", norm={norm:.2g}"
323
+ mean = stats.mean().item()
324
+ rms = (stats**2).mean().sqrt().item()
325
+ ans += f", mean={mean:.3g}, rms={rms:.3g}"
326
+
327
+ # OK, "ans" contains the actual stats, e.g.
328
+ # ans = "percentiles: \
329
+ # [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], \
330
+ # mean=0.5, rms=0.5"
331
+
332
+ sizes = [x.tensor.shape[0] for x in stats_list]
333
+ size_str = (
334
+ f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
335
+ )
336
+ maybe_class_name = (
337
+ f" type={self.class_name}," if self.class_name is not None else ""
338
+ )
339
+ print(
340
+ f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, "
341
+ f"{stats_type} {ans}"
342
+ )
343
+
344
+
345
+ class ScalarDiagnostic(object):
346
+ """This class is not directly used by the user, it is responsible for
347
+ collecting diagnostics for a single module (subclass of torch.nn.Module) that
348
+ represents some kind of nonlinearity, e.g. ReLU, sigmoid, etc.
349
+ """
350
+
351
+ def __init__(self, opts: TensorDiagnosticOptions, name: str):
352
+ self.opts = opts
353
+ self.name = name
354
+ self.class_name = None # will assign in accumulate()
355
+ self.is_forward_pass = True
356
+
357
+ self.tick_scale = None
358
+
359
+ self.saved_inputs = []
360
+ self.is_ok = True
361
+
362
+ self.counts = None
363
+ self.sum_grad = None
364
+ self.sum_gradsq = None
365
+ self.sum_abs_grad = None
366
+
367
+ def accumulate_input(self, x: Tensor, class_name: Optional[str] = None):
368
+ """
369
+ Called in forward pass.
370
+ """
371
+ if not self.is_forward_pass:
372
+ # in case we did a forward pass without a backward pass, for some reason.
373
+ self.saved_inputs = []
374
+ self.is_forward_pass = True
375
+
376
+ if class_name is not None:
377
+ self.class_name = class_name
378
+ if not self.is_ok:
379
+ return
380
+
381
+ limit = 10
382
+ if len(self.saved_inputs) > limit:
383
+ print(
384
+ f"ERROR: forward pass called for this module over {limit} times "
385
+ f"with no backward pass. Will not accumulate scalar stats."
386
+ )
387
+ self.is_ok = False
388
+ return
389
+ self.saved_inputs.append(x)
390
+
391
+ def accumulate_output_grad(self, grad: Tensor):
392
+ if not self.is_ok:
393
+ return
394
+ if self.is_forward_pass:
395
+ self.is_forward_pass = False
396
+
397
+ last_shape = (
398
+ "n/a" if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape
399
+ )
400
+ if len(self.saved_inputs) == 0 or grad.shape != last_shape:
401
+ print(
402
+ f"ERROR: shape mismatch or no forward activation present when backward "
403
+ f"pass called: grad shape ={tuple(grad.shape)}"
404
+ f", num-saved-inputs={len(self.saved_inputs)}"
405
+ f", shape-of-last-saved-input={last_shape}"
406
+ )
407
+ self.is_ok = False
408
+ return
409
+
410
+ x = self.saved_inputs.pop()
411
+ self.process_input_and_grad(x, grad)
412
+
413
+ def process_input_and_grad(self, x: Tensor, grad: Tensor):
414
+ assert x.shape == grad.shape
415
+ x = x.flatten()
416
+ grad = grad.flatten()
417
+
418
+ num_ticks_per_side = 256
419
+
420
+ if self.tick_scale is None:
421
+ x_abs_sorted = x.abs().sort()[0]
422
+ # take the 98th percentile as the largest value we count separately.
423
+ index = int(x.numel() * 0.98)
424
+ self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side)
425
+
426
+ # integerize from tick * (-num ticks_per_side .. num_ticks_per_side - 1]
427
+ self.counts = torch.zeros(
428
+ 2 * num_ticks_per_side, dtype=torch.long, device=x.device
429
+ )
430
+ self.sum_grad = torch.zeros(
431
+ 2 * num_ticks_per_side, dtype=torch.double, device=x.device
432
+ )
433
+ # sum_gradsq is for getting error bars.
434
+ self.sum_gradsq = torch.zeros(
435
+ 2 * num_ticks_per_side, dtype=torch.double, device=x.device
436
+ )
437
+ self.sum_abs_grad = torch.zeros(
438
+ 2 * num_ticks_per_side, dtype=torch.double, device=x.device
439
+ )
440
+
441
+ # this will round down.
442
+ x = (x / self.tick_scale).to(torch.long)
443
+ x = x.clamp_(min=-num_ticks_per_side, max=num_ticks_per_side - 1)
444
+ x = x + num_ticks_per_side
445
+
446
+ self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x))
447
+ self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double))
448
+ self.sum_gradsq.index_add_(
449
+ dim=0, index=x, source=(grad * grad).to(torch.double)
450
+ )
451
+ self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double))
452
+
453
+ def print_diagnostics(self):
454
+ """Print diagnostics."""
455
+ if self.is_ok is False or self.counts is None:
456
+ print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}")
457
+ return
458
+
459
+ counts = self.counts.to("cpu")
460
+ sum_grad = self.sum_grad.to(device="cpu", dtype=torch.float32)
461
+ sum_gradsq = self.sum_gradsq.to(device="cpu", dtype=torch.float32)
462
+ sum_abs_grad = self.sum_abs_grad.to(device="cpu", dtype=torch.float32)
463
+
464
+ counts_cumsum = counts.cumsum(dim=0)
465
+ counts_tot = counts_cumsum[-1]
466
+
467
+ # subdivide the distribution up into `num_bins` intervals for analysis, for
468
+ # greater statistical significance. each bin corresponds to multiple of the
469
+ # original 'tick' intervals.
470
+ num_bins = 20
471
+
472
+ # integer division
473
+ counts_per_bin = (counts_tot // num_bins) + 1
474
+ bin_indexes = counts_cumsum // counts_per_bin
475
+ bin_indexes = bin_indexes.clamp(min=0, max=num_bins).to(torch.long)
476
+
477
+ bin_counts = torch.zeros(num_bins, dtype=torch.long)
478
+ bin_counts.index_add_(dim=0, index=bin_indexes, source=counts)
479
+ bin_grad = torch.zeros(num_bins)
480
+ bin_grad.index_add_(dim=0, index=bin_indexes, source=sum_grad)
481
+ bin_gradsq = torch.zeros(num_bins)
482
+ bin_gradsq.index_add_(dim=0, index=bin_indexes, source=sum_gradsq)
483
+ bin_abs_grad = torch.zeros(num_bins)
484
+ bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad)
485
+
486
+ bin_boundary_counts = (
487
+ torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
488
+ )
489
+ bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts)
490
+ # boundaries are the "x" values between the bins, e.g. corresponding to the
491
+ # locations of percentiles of the distribution.
492
+ num_ticks_per_side = counts.numel() // 2
493
+ bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale
494
+
495
+ bin_grad = bin_grad / (bin_counts + 1)
496
+ bin_conf_interval = bin_gradsq.sqrt() / (
497
+ bin_counts + 1
498
+ ) # consider this a standard deviation.
499
+ # bin_grad / bin_abs_grad will give us a sense for how important in a practical
500
+ # sense, the gradients are.
501
+ bin_abs_grad = bin_abs_grad / (bin_counts + 1)
502
+
503
+ bin_rel_grad = bin_grad / (bin_abs_grad + 1.0e-20)
504
+ bin_conf = bin_grad / (bin_conf_interval + 1.0e-20)
505
+
506
+ def tensor_to_str(x: Tensor):
507
+ x = ["%.2g" % f for f in x]
508
+ x = "[" + " ".join(x) + "]"
509
+ return x
510
+
511
+ maybe_class_name = (
512
+ f" type={self.class_name}," if self.class_name is not None else ""
513
+ )
514
+
515
+ print(
516
+ f"module={self.name},{maybe_class_name} "
517
+ f"bin-boundaries={tensor_to_str(bin_boundaries)}, "
518
+ f"rel_grad={tensor_to_str(bin_rel_grad)}, "
519
+ f"grad_conf={tensor_to_str(bin_conf)}"
520
+ )
521
+
522
+
523
+ class ModelDiagnostic(object):
524
+ """This class stores diagnostics for all tensors in the torch.nn.Module.
525
+
526
+ Args:
527
+ opts:
528
+ Options object.
529
+ """
530
+
531
+ def __init__(self, opts: Optional[TensorDiagnosticOptions] = None):
532
+ # In this dictionary, the keys are tensors names and the values
533
+ # are corresponding TensorDiagnostic objects.
534
+ if opts is None:
535
+ self.opts = TensorDiagnosticOptions()
536
+ else:
537
+ self.opts = opts
538
+ self.diagnostics = dict()
539
+
540
+ def __getitem__(self, name: str):
541
+ T = ScalarDiagnostic if name[-7:] == ".scalar" else TensorDiagnostic
542
+ if name not in self.diagnostics:
543
+ self.diagnostics[name] = T(self.opts, name)
544
+ return self.diagnostics[name]
545
+
546
+ def print_diagnostics(self):
547
+ """Print diagnostics for each tensor."""
548
+ for k in sorted(self.diagnostics.keys()):
549
+ self.diagnostics[k].print_diagnostics()
550
+
551
+
552
+ def get_class_name(module: nn.Module):
553
+ ans = type(module).__name__
554
+ # we put the below in try blocks in case anyone is using a different version of
555
+ # these modules that might have different member names.
556
+ if ans == "Balancer" or ans == "ActivationBalancer":
557
+ try:
558
+ ans += f"[{float(module.min_positive)},{float(module.max_positive)}," \
559
+ f"{float(module.min_abs)},{float(module.max_abs)}]"
560
+ except:
561
+ pass
562
+ elif ans == "AbsValuePenalizer":
563
+ try:
564
+ ans += f"[{module.limit}]"
565
+ except:
566
+ pass
567
+ return ans
568
+
569
+
570
+ def attach_diagnostics(
571
+ model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None
572
+ ) -> ModelDiagnostic:
573
+ """Attach a ModelDiagnostic object to the model by
574
+ 1) registering forward hook and backward hook on each module, to accumulate
575
+ its output tensors and gradient tensors, respectively;
576
+ 2) registering backward hook on each module parameter, to accumulate its
577
+ values and gradients.
578
+
579
+ Args:
580
+ model:
581
+ the model to be analyzed.
582
+ opts:
583
+ Options object.
584
+
585
+ Returns:
586
+ The ModelDiagnostic object attached to the model.
587
+ """
588
+
589
+ ans = ModelDiagnostic(opts)
590
+ for name, module in model.named_modules():
591
+ if name == "":
592
+ name = "<top-level>"
593
+
594
+ # Setting model_diagnostic=ans and n=name below, instead of trying to
595
+ # capture the variables, ensures that we use the current values.
596
+ # (this matters for `name`, since the variable gets overwritten).
597
+ # These closures don't really capture by value, only by
598
+ # "the final value the variable got in the function" :-(
599
+ def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
600
+ if isinstance(_output, tuple) and len(_output) == 1:
601
+ _output = _output[0]
602
+
603
+ if isinstance(_output, Tensor) and _output.dtype in (
604
+ torch.float32,
605
+ torch.float16,
606
+ torch.float64,
607
+ ):
608
+ _model_diagnostic[f"{_name}.output"].accumulate(
609
+ _output, class_name=get_class_name(_module)
610
+ )
611
+ elif isinstance(_output, tuple):
612
+ for i, o in enumerate(_output):
613
+ if isinstance(o, Tensor) and o.dtype in (
614
+ torch.float32,
615
+ torch.float16,
616
+ torch.float64,
617
+ ):
618
+ _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
619
+ o, class_name=get_class_name(_module)
620
+ )
621
+
622
+ def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
623
+ if isinstance(_output, tuple) and len(_output) == 1:
624
+ _output = _output[0]
625
+ if isinstance(_output, Tensor) and _output.dtype in (
626
+ torch.float32,
627
+ torch.float16,
628
+ torch.float64,
629
+ ):
630
+ _model_diagnostic[f"{_name}.grad"].accumulate(
631
+ _output, class_name=get_class_name(_module)
632
+ )
633
+ elif isinstance(_output, tuple):
634
+ for i, o in enumerate(_output):
635
+ if isinstance(o, Tensor) and o.dtype in (
636
+ torch.float32,
637
+ torch.float16,
638
+ torch.float64,
639
+ ):
640
+ _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
641
+ o, class_name=get_class_name(_module)
642
+ )
643
+
644
+ module.register_forward_hook(forward_hook)
645
+ module.register_backward_hook(backward_hook)
646
+
647
+ if type(module).__name__ in [
648
+ "Sigmoid",
649
+ "Tanh",
650
+ "ReLU",
651
+ "TanSwish",
652
+ "Swish",
653
+ "DoubleSwish",
654
+ "Swoosh",
655
+ ]:
656
+ # For these specific module types, accumulate some additional diagnostics
657
+ # that can help us improve the activation function. These require a lot of
658
+ # memory, to save the forward activations, so limit this to some select
659
+ # classes. Note: this will not work correctly for all model types.
660
+ def scalar_forward_hook(
661
+ _module, _input, _output, _model_diagnostic=ans, _name=name
662
+ ):
663
+ if isinstance(_input, tuple):
664
+ (_input,) = _input
665
+ assert isinstance(_input, Tensor)
666
+ _model_diagnostic[f"{_name}.scalar"].accumulate_input(
667
+ _input, class_name=get_class_name(_module)
668
+ )
669
+
670
+ def scalar_backward_hook(
671
+ _module, _input, _output, _model_diagnostic=ans, _name=name
672
+ ):
673
+ if isinstance(_output, tuple):
674
+ (_output,) = _output
675
+ assert isinstance(_output, Tensor)
676
+ _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)
677
+
678
+ module.register_forward_hook(scalar_forward_hook)
679
+ module.register_backward_hook(scalar_backward_hook)
680
+
681
+ for name, parameter in model.named_parameters():
682
+
683
+ def param_backward_hook(
684
+ grad, _parameter=parameter, _model_diagnostic=ans, _name=name
685
+ ):
686
+ _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter)
687
+ _model_diagnostic[f"{_name}.param_grad"].accumulate(grad)
688
+
689
+ try:
690
+ parameter.register_hook(param_backward_hook)
691
+ except:
692
+ logging.warning(
693
+ f"Warning: could not register backward hook for parameter {name}, "
694
+ f"it might not be differentiable."
695
+ )
696
+
697
+ return ans
698
+
699
+
700
+ def _test_tensor_diagnostic():
701
+ opts = TensorDiagnosticOptions(512)
702
+
703
+ diagnostic = TensorDiagnostic(opts, "foo")
704
+
705
+ for _ in range(10):
706
+ diagnostic.accumulate(torch.randn(50, 100) * 10.0)
707
+
708
+ diagnostic.print_diagnostics()
709
+
710
+ model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 80))
711
+
712
+ diagnostic = attach_diagnostics(model, opts)
713
+ for _ in range(10):
714
+ T = random.randint(200, 300)
715
+ x = torch.randn(T, 100)
716
+ y = model(x)
717
+ y.sum().backward()
718
+
719
+ diagnostic.print_diagnostics()
720
+
721
+
722
+ if __name__ == "__main__":
723
+ _test_tensor_diagnostic()
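For reference, a minimal sketch (not part of this commit) of how the diagnostics above can be attached to a model during training. The import path is an assumption, consistent with the other zipvoice.utils modules added in this commit, and the model and data shapes are placeholders.

# Sketch only: attach diagnostics to a small model for a few steps,
# mirroring _test_tensor_diagnostic() above. Import path is assumed.
import torch
from torch import nn

from zipvoice.utils.diagnostics import TensorDiagnosticOptions, attach_diagnostics

model = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 80))
opts = TensorDiagnosticOptions(512)  # same option value as the self-test above
diagnostic = attach_diagnostics(model, opts)

for _ in range(5):
    x = torch.randn(16, 80)
    model(x).sum().backward()  # forward/backward hooks accumulate the stats
    model.zero_grad()

diagnostic.print_diagnostics()  # one summary line per tracked tensor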
zipvoice/utils/feature.py ADDED
@@ -0,0 +1,120 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2024 Xiaomi Corp. (authors: Han Zhu)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torchaudio
24
+ from lhotse.features.base import FeatureExtractor, register_extractor
25
+ from lhotse.utils import Seconds, compute_num_frames
26
+
27
+
28
+ @dataclass
29
+ class VocosFbankConfig:
30
+ sampling_rate: int = 24000
31
+ n_mels: int = 100
32
+ n_fft: int = 1024
33
+ hop_length: int = 256
+ device: str = "cpu"  # referenced by the `device` property of VocosFbank below
34
+
35
+
36
+ @register_extractor
37
+ class VocosFbank(FeatureExtractor):
38
+
39
+ name = "VocosFbank"
40
+ config_type = VocosFbankConfig
41
+
42
+ def __init__(self, num_channels: int = 1):
43
+ config = VocosFbankConfig()
44
+ super().__init__(config=config)
45
+ assert num_channels in (1, 2)
46
+ self.num_channels = num_channels
47
+ self.fbank = torchaudio.transforms.MelSpectrogram(
48
+ sample_rate=self.config.sampling_rate,
49
+ n_fft=self.config.n_fft,
50
+ hop_length=self.config.hop_length,
51
+ n_mels=self.config.n_mels,
52
+ center=True,
53
+ power=1,
54
+ )
55
+
56
+ def _feature_fn(self, sample):
57
+ mel = self.fbank(sample)
58
+ logmel = mel.clamp(min=1e-7).log()
59
+
60
+ return logmel
61
+
62
+ @property
63
+ def device(self) -> Union[str, torch.device]:
64
+ return self.config.device
65
+
66
+ def feature_dim(self, sampling_rate: int) -> int:
67
+ return self.config.n_mels
68
+
69
+ def extract(
70
+ self,
71
+ samples: Union[np.ndarray, torch.Tensor],
72
+ sampling_rate: int,
73
+ ) -> Union[np.ndarray, torch.Tensor]:
74
+ # Check for sampling rate compatibility.
75
+ expected_sr = self.config.sampling_rate
76
+ assert sampling_rate == expected_sr, (
77
+ f"Mismatched sampling rate: extractor expects {expected_sr}, "
78
+ f"got {sampling_rate}"
79
+ )
80
+ is_numpy = False
81
+ if not isinstance(samples, torch.Tensor):
82
+ samples = torch.from_numpy(samples)
83
+ is_numpy = True
84
+
85
+ if len(samples.shape) == 1:
86
+ samples = samples.unsqueeze(0)
87
+ else:
88
+ assert samples.ndim == 2, samples.shape
89
+
90
+ if self.num_channels == 1:
91
+ if samples.shape[0] == 2:
92
+ samples = samples.mean(dim=0, keepdim=True)
93
+ else:
94
+ assert samples.shape[0] == 2, samples.shape
95
+
96
+ mel = self._feature_fn(samples)
97
+ # (1, n_mels, time) or (2, n_mels, time)
98
+ mel = mel.reshape(-1, mel.shape[-1]).t()
99
+ # (time, n_mels) or (time, 2 * n_mels)
100
+
101
+ num_frames = compute_num_frames(
102
+ samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
103
+ )
104
+
105
+ if mel.shape[0] > num_frames:
106
+ mel = mel[:num_frames]
107
+ elif mel.shape[0] < num_frames:
108
+ mel = mel.unsqueeze(0)
109
+ mel = torch.nn.functional.pad(
110
+ mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
111
+ ).squeeze(0)
112
+
113
+ if is_numpy:
114
+ return mel.cpu().numpy()
115
+ else:
116
+ return mel
117
+
118
+ @property
119
+ def frame_shift(self) -> Seconds:
120
+ return self.config.hop_length / self.config.sampling_rate
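A minimal usage sketch (not part of this commit) for the VocosFbank extractor above, assuming a mono waveform already resampled to the configured 24 kHz:

# Sketch only: extract 100-dim log-mel features with VocosFbank.
import torch

from zipvoice.utils.feature import VocosFbank

extractor = VocosFbank()            # n_mels=100, n_fft=1024, hop_length=256 @ 24 kHz
samples = torch.randn(1, 24000)     # (channels, time): one second of audio
feats = extractor.extract(samples, sampling_rate=24000)
print(feats.shape)                  # (num_frames, 100); about 94 frames per second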
zipvoice/utils/hooks.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright 2021-2024 Xiaomi Corporation (authors: Zengwei Yao,
2
+ # Daniel Povey,
3
+ # Zengrui Jin,)
4
+ #
5
+ # See ../../LICENSE for clarification regarding multiple authors
6
+ #
7
+ # Licensed under the Apache License, Version 2.0 (the "License");
8
+ # you may not use this file except in compliance with the License.
9
+ # You may obtain a copy of the License at
10
+ #
11
+ # http://www.apache.org/licenses/LICENSE-2.0
12
+ #
13
+ # Unless required by applicable law or agreed to in writing, software
14
+ # distributed under the License is distributed on an "AS IS" BASIS,
15
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
+ # See the License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+ import logging
20
+ import random
21
+
22
+ import torch
23
+ from torch import Tensor, nn
24
+
25
+
26
+ def register_inf_check_hooks(model: nn.Module) -> None:
27
+ """Registering forward hook on each module, to check
28
+ whether its output tensors is not finite.
29
+
30
+ Args:
31
+ model:
32
+ the model to be analyzed.
33
+ """
34
+
35
+ for name, module in model.named_modules():
36
+ if name == "":
37
+ name = "<top-level>"
38
+
39
+ # default param _name is a way to capture the current value of the variable
40
+ # "name".
41
+ def forward_hook(_module, _input, _output, _name=name):
42
+ if isinstance(_output, Tensor):
43
+ try:
44
+ if not torch.isfinite(_output.to(torch.float32).sum()):
45
+ logging.warning(f"The sum of {_name}.output is not finite")
46
+ except RuntimeError: # e.g. CUDA out of memory
47
+ pass
48
+ elif isinstance(_output, tuple):
49
+ for i, o in enumerate(_output):
50
+ if isinstance(o, tuple):
51
+ o = o[0]
52
+ if not isinstance(o, Tensor):
53
+ continue
54
+ try:
55
+ if not torch.isfinite(o.to(torch.float32).sum()):
56
+ logging.warning(
57
+ f"The sum of {_name}.output[{i}] is not finite"
58
+ )
59
+ except RuntimeError: # e.g. CUDA out of memory
60
+ pass
61
+
62
+ # default param _name is a way to capture the current value of the variable
63
+ # "name".
64
+ def backward_hook(_module, _input, _output, _name=name):
65
+ if isinstance(_output, Tensor):
66
+ try:
67
+ if not torch.isfinite(_output.to(torch.float32).sum()):
68
+ logging.warning(f"The sum of {_name}.grad is not finite")
69
+ except RuntimeError: # e.g. CUDA out of memory
70
+ pass
71
+
72
+ elif isinstance(_output, tuple):
73
+ for i, o in enumerate(_output):
74
+ if isinstance(o, tuple):
75
+ o = o[0]
76
+ if not isinstance(o, Tensor):
77
+ continue
78
+ if not torch.isfinite(o.to(torch.float32).sum()):
79
+ logging.warning(f"The sum of {_name}.grad[{i}] is not finite")
80
+
81
+ module.register_forward_hook(forward_hook)
82
+ module.register_backward_hook(backward_hook)
83
+
84
+ for name, parameter in model.named_parameters():
85
+
86
+ def param_backward_hook(grad, _name=name):
87
+ if not torch.isfinite(grad.to(torch.float32).sum()):
88
+ logging.warning(f"The sum of {_name}.param_grad is not finite")
89
+
90
+ try:
91
+ parameter.register_hook(param_backward_hook)
92
+ except Exception as e:
93
+ logging.warning(
94
+ f"Warning: could not register backward hook for parameter {name}"
95
+ f" with error {e}, it might not be differentiable."
96
+ )
97
+
98
+
99
+ def _test_inf_check_hooks():
100
+ model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80))
101
+
102
+ register_inf_check_hooks(model)
103
+ for _ in range(10):
104
+ T = random.randint(200, 300)
105
+ x = torch.randn(T, 100) + float("inf") * (T % 2)
106
+ y = model(x)
107
+ y.sum().backward()
108
+
109
+
110
+ if __name__ == "__main__":
111
+ _test_inf_check_hooks()
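A short sketch (not part of this commit) of enabling the inf/nan checks on a model before training; non-finite outputs or parameter gradients are then reported through logging:

# Sketch only: register the hooks, then run a deliberately bad batch.
import logging

import torch
from torch import nn

from zipvoice.utils.hooks import register_inf_check_hooks

logging.basicConfig(level=logging.WARNING)

model = nn.Linear(10, 10)
register_inf_check_hooks(model)

bad = torch.full((2, 10), float("inf"))
model(bad).sum().backward()  # warns that <top-level>.output (and the weight grad) are not finite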
zipvoice/utils/lr_scheduler.py ADDED
@@ -0,0 +1,245 @@
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+ from typing import List, Optional, Union
19
+
20
+ import torch
21
+ from torch.optim import Optimizer
22
+
23
+
24
+ class LRScheduler(object):
25
+ """
26
+ Base-class for learning rate schedulers where the learning-rate depends on both the
27
+ batch and the epoch.
28
+ """
29
+
30
+ def __init__(self, optimizer: Optimizer, verbose: bool = False):
31
+ # Attach optimizer
32
+ if not isinstance(optimizer, Optimizer):
33
+ raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
34
+ self.optimizer = optimizer
35
+ self.verbose = verbose
36
+
37
+ for group in optimizer.param_groups:
38
+ group.setdefault("base_lr", group["lr"])
39
+
40
+ self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
41
+
42
+ self.epoch = 0
43
+ self.batch = 0
44
+
45
+ def state_dict(self):
46
+ """Returns the state of the scheduler as a :class:`dict`.
47
+
48
+ It contains an entry for every variable in self.__dict__ which
49
+ is not the optimizer.
50
+ """
51
+ return {
52
+ # the user might try to override the base_lr, so don't include this in the
53
+ # state. previously they were included.
54
+ # "base_lrs": self.base_lrs,
55
+ "epoch": self.epoch,
56
+ "batch": self.batch,
57
+ }
58
+
59
+ def load_state_dict(self, state_dict):
60
+ """Loads the schedulers state.
61
+
62
+ Args:
63
+ state_dict (dict): scheduler state. Should be an object returned
64
+ from a call to :meth:`state_dict`.
65
+ """
66
+ # the things with base_lrs are a work-around for a previous problem
67
+ # where base_lrs were written with the state dict.
68
+ base_lrs = self.base_lrs
69
+ self.__dict__.update(state_dict)
70
+ self.base_lrs = base_lrs
71
+
72
+ def get_last_lr(self) -> List[float]:
73
+ """Return last computed learning rate by current scheduler.
74
+ Will be a list of float."""
75
+ return self._last_lr
76
+
77
+ def get_lr(self):
78
+ # Compute list of learning rates from self.epoch and self.batch and
79
+ # self.base_lrs; this must be overloaded by the user.
80
+ # e.g. return [some_formula(self.batch, self.epoch, base_lr)
81
+ # for base_lr in self.base_lrs ]
82
+ raise NotImplementedError
83
+
84
+ def step_batch(self, batch: Optional[int] = None) -> None:
85
+ # Step the batch index, or just set it. If `batch` is specified, it
86
+ # must be the batch index from the start of training, i.e. summed over
87
+ # all epochs.
88
+ # You can call this in any order; if you don't provide 'batch', it should
89
+ # of course be called once per batch.
90
+ if batch is not None:
91
+ self.batch = batch
92
+ else:
93
+ self.batch = self.batch + 1
94
+ self._set_lrs()
95
+
96
+ def step_epoch(self, epoch: Optional[int] = None):
97
+ # Step the epoch index, or just set it. If you provide the 'epoch' arg, you
98
+ # should call this at the start of the epoch; if you don't provide the 'epoch'
99
+ # arg, you should call it at the end of the epoch.
100
+ if epoch is not None:
101
+ self.epoch = epoch
102
+ else:
103
+ self.epoch = self.epoch + 1
104
+ self._set_lrs()
105
+
106
+ def _set_lrs(self):
107
+ values = self.get_lr()
108
+ assert len(values) == len(self.optimizer.param_groups)
109
+
110
+ for i, data in enumerate(zip(self.optimizer.param_groups, values)):
111
+ param_group, lr = data
112
+ param_group["lr"] = lr
113
+ self.print_lr(self.verbose, i, lr)
114
+ self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
115
+
116
+ def print_lr(self, is_verbose, group, lr):
117
+ """Display the current learning rate."""
118
+ if is_verbose:
119
+ logging.warning(
120
+ f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
121
+ f" of group {group} to {lr:.4e}."
122
+ )
123
+
124
+
125
+ class Eden(LRScheduler):
126
+ """
127
+ Eden scheduler.
128
+ The basic formula (before warmup) is:
129
+ lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
130
+ (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
131
+ where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
132
+ and then stays constant at 1.
133
+
134
+ If you don't have the concept of epochs, or one epoch takes a very long time,
135
+ you can replace the notion of 'epoch' with some measure of the amount of data
136
+ processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to
137
+ some measure representing "quite a lot of data": say, one fifth or one third
138
+ of an entire training run, but it doesn't matter much. You could also use
139
+ Eden2 which has only the notion of batches.
140
+
141
+ We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
142
+
143
+ Args:
144
+ optimizer: the optimizer to change the learning rates on
145
+ lr_batches: the number of batches after which we start significantly
146
+ decreasing the learning rate, suggest 5000.
147
+ lr_epochs: the number of epochs after which we start significantly
148
+ decreasing the learning rate, suggest 6 if you plan to do e.g.
149
+ 20 to 40 epochs, but may need smaller number if dataset is huge
150
+ and you will do few epochs.
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ optimizer: Optimizer,
156
+ lr_batches: Union[int, float],
157
+ lr_epochs: Union[int, float],
158
+ warmup_batches: Union[int, float] = 500.0,
159
+ warmup_start: float = 0.5,
160
+ verbose: bool = False,
161
+ ):
162
+ super(Eden, self).__init__(optimizer, verbose)
163
+ self.lr_batches = lr_batches
164
+ self.lr_epochs = lr_epochs
165
+ self.warmup_batches = warmup_batches
166
+
167
+ assert 0.0 <= warmup_start <= 1.0, warmup_start
168
+ self.warmup_start = warmup_start
169
+
170
+ def get_lr(self):
171
+ factor = (
172
+ (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
173
+ ) ** -0.25 * (
174
+ ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
175
+ )
176
+ warmup_factor = (
177
+ 1.0
178
+ if self.batch >= self.warmup_batches
179
+ else self.warmup_start
180
+ + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
181
+ # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
182
+ )
183
+
184
+ return [x * factor * warmup_factor for x in self.base_lrs]
185
+
186
+
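To make the formula above concrete, a small worked example (annotation, not part of the commit) evaluating the Eden learning rate at a few points, using the suggested base_lr=0.04, lr_batches=5000, lr_epochs=6 and the default warmup settings:

# Worked example of the Eden formula (sketch only).
base_lr, lr_batches, lr_epochs = 0.04, 5000.0, 6.0
warmup_batches, warmup_start = 500.0, 0.5

def eden_lr(batch, epoch):
    factor = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * (
        ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
    )
    warmup = (
        1.0
        if batch >= warmup_batches
        else warmup_start + (1.0 - warmup_start) * (batch / warmup_batches)
    )
    return base_lr * factor * warmup

print(eden_lr(batch=0, epoch=0))       # 0.02    (warmup_start * base_lr)
print(eden_lr(batch=5000, epoch=0))    # ~0.0336 (batch factor = 2 ** -0.25)
print(eden_lr(batch=50000, epoch=20))  # ~0.0068 (both factors well below 1)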
187
+ class FixedLRScheduler(LRScheduler):
188
+ """
189
+ Fixed learning rate scheduler.
190
+
191
+ Args:
192
+ optimizer: the optimizer to change the learning rates on
193
+ """
194
+
195
+ def __init__(
196
+ self,
197
+ optimizer: Optimizer,
198
+ verbose: bool = False,
199
+ ):
200
+ super(FixedLRScheduler, self).__init__(optimizer, verbose)
201
+
202
+ def get_lr(self):
203
+
204
+ return [x for x in self.base_lrs]
205
+
206
+
207
+ def _test_eden():
208
+ m = torch.nn.Linear(100, 100)
209
+ from zipvoice.utils.optim import ScaledAdam
210
+
211
+ optim = ScaledAdam(m.parameters(), lr=0.03)
212
+
213
+ scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
214
+
215
+ for epoch in range(10):
216
+ scheduler.step_epoch(epoch) # sets epoch to `epoch`
217
+
218
+ for step in range(20):
219
+ x = torch.randn(200, 100).detach()
220
+ x.requires_grad = True
221
+ y = m(x)
222
+ dy = torch.randn(200, 100).detach()
223
+ f = (y * dy).sum()
224
+ f.backward()
225
+
226
+ optim.step()
227
+ scheduler.step_batch()
228
+ optim.zero_grad()
229
+
230
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
231
+ logging.info(f"state dict = {scheduler.state_dict()}")
232
+
233
+
234
+ if __name__ == "__main__":
235
+ torch.set_num_threads(1)
236
+ torch.set_num_interop_threads(1)
237
+ logging.getLogger().setLevel(logging.INFO)
238
+ import subprocess
239
+
240
+ s = subprocess.check_output(
241
+ "git status -uno .; git log -1; git diff HEAD .", shell=True
242
+ )
243
+ logging.info(s)
244
+
245
+ _test_eden()
zipvoice/utils/optim.py ADDED
@@ -0,0 +1,868 @@
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import contextlib
18
+ import logging
19
+ from collections import defaultdict
20
+ from typing import Dict, List, Tuple
21
+
22
+ import torch
23
+ from lhotse.utils import fix_random_seed
24
+ from torch import Tensor
25
+ from torch.optim import Optimizer
26
+
27
+
28
+ class BatchedOptimizer(Optimizer):
29
+ """
30
+ This class adds to class Optimizer the capability to optimize parameters in batches:
31
+ it will stack the parameters and their grads for you so the optimizer can work
32
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
33
+ as it reduces the number of kernels launched in the optimizer.
34
+
35
+ Args:
36
+ params: iterable of parameters or dicts defining parameter groups, as accepted by torch.optim.Optimizer.
37
+ """
38
+
39
+ def __init__(self, params, defaults):
40
+ super(BatchedOptimizer, self).__init__(params, defaults)
41
+
42
+ @contextlib.contextmanager
43
+ def batched_params(self, param_group, group_params_names):
44
+ """
45
+ This function returns (technically, yields) a list of
46
+ tuples (p, state), where
47
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
48
+ that share the same shape, and its gradient is also stacked;
49
+ `state` is the state corresponding to this batch of parameters
50
+ (it will be physically located in the "state" for one of the real
51
+ parameters, the last one that has any particular shape and dtype).
52
+
53
+ This function is decorated as a context manager so that it can
54
+ write parameters back to their "real" locations.
55
+
56
+ The idea is, instead of doing:
57
+ <code>
58
+ for p in group["params"]:
59
+ state = self.state[p]
60
+ ...
61
+ </code>
62
+ you can do:
63
+ <code>
64
+ with self.batched_params(group["params"]) as batches:
65
+ for p, state, p_names in batches:
66
+ ...
67
+ </code>
68
+
69
+ Args:
70
+ group: a parameter group, which is a list of parameters; should be
71
+ one of self.param_groups.
72
+ group_params_names: name for each parameter in group,
73
+ which is List[str].
74
+ """
75
+ batches = defaultdict(
76
+ list
77
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
78
+ batches_names = defaultdict(
79
+ list
80
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
81
+
82
+ assert len(param_group) == len(group_params_names)
83
+ for p, named_p in zip(param_group, group_params_names):
84
+ key = (str(p.dtype), *p.shape)
85
+ batches[key].append(p)
86
+ batches_names[key].append(named_p)
87
+
88
+ batches_names_keys = list(batches_names.keys())
89
+ sorted_idx = sorted(
90
+ range(len(batches_names)), key=lambda i: batches_names_keys[i]
91
+ )
92
+ batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
93
+ batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
94
+
95
+ stacked_params_dict = dict()
96
+
97
+ # turn batches into a list, in deterministic order.
98
+ # tuples will contain tuples of (stacked_param, state, stacked_params_names),
99
+ # one for each batch in `batches`.
100
+ tuples = []
101
+
102
+ for batch, batch_names in zip(batches, batches_names):
103
+ p = batch[0]
104
+ # we arbitrarily store the state in the
105
+ # state corresponding to the 1st parameter in the
106
+ # group. class Optimizer will take care of saving/loading state.
107
+ state = self.state[p]
108
+ p_stacked = torch.stack(batch)
109
+ grad = torch.stack(
110
+ [torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
111
+ )
112
+ p_stacked.grad = grad
113
+ stacked_params_dict[key] = p_stacked
114
+ tuples.append((p_stacked, state, batch_names))
115
+
116
+ yield tuples # <-- calling code will do the actual optimization here!
117
+
118
+ for (stacked_params, _state, _names), batch in zip(tuples, batches):
119
+ for i, p in enumerate(batch): # batch is list of Parameter
120
+ p.copy_(stacked_params[i])
121
+
122
+
123
+ def basic_step(group, p, state, grad):
124
+ # computes basic Adam update using beta2 (dividing by gradient stddev) only. no
125
+ # momentum yet.
126
+ lr = group["lr"]
127
+ if p.numel() == p.shape[0]:
128
+ lr = lr * group["scalar_lr_scale"]
129
+ beta2 = group["betas"][1]
130
+ eps = group["eps"]
131
+ # p shape: (batch_size,) or (batch_size, 1, [1,..])
132
+ try:
133
+ exp_avg_sq = state[
134
+ "exp_avg_sq"
135
+ ] # shape: (batch_size,) or (batch_size, 1, [1,..])
136
+ except KeyError:
137
+ exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
138
+ state["exp_avg_sq"] = exp_avg_sq
139
+
140
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
141
+
142
+ # bias_correction2 is like in Adam.
143
+ # slower update at the start will help stability anyway.
144
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1)
145
+ if bias_correction2 < 0.99:
146
+ # note: not in-place.
147
+ exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
148
+ denom = exp_avg_sq.sqrt().add_(eps)
149
+
150
+ return -lr * grad / denom
151
+
152
+
153
+ def scaling_step(group, p, state, grad):
154
+ delta = basic_step(group, p, state, grad)
155
+ if p.numel() == p.shape[0]:
156
+ return delta
157
+ # there is no scaling for scalar parameters.
158
+ # (p.shape[0] is the batch of parameters.)
159
+
160
+ step = state["step"]
161
+ size_update_period = group["size_update_period"]
162
+
163
+ try:
164
+ param_rms = state["param_rms"]
165
+ scale_grads = state["scale_grads"]
166
+ scale_exp_avg_sq = state["scale_exp_avg_sq"]
167
+ except KeyError:
168
+ # we know p.ndim > 1 because we'd have returned above if not, so don't worry
169
+ # about the special case of dim=[] that pytorch treats inconsistently.
170
+ param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
171
+ param_rms = param_rms.to(torch.float)
172
+ scale_exp_avg_sq = torch.zeros_like(param_rms)
173
+ scale_grads = torch.zeros(
174
+ size_update_period,
175
+ *param_rms.shape,
176
+ dtype=torch.float,
177
+ device=p.device,
178
+ )
179
+ state["param_rms"] = param_rms
180
+ state["scale_grads"] = scale_grads
181
+ state["scale_exp_avg_sq"] = scale_exp_avg_sq
182
+
183
+ # on every step, update the gradient w.r.t. the scale of the parameter, we
184
+ # store these as a batch and periodically update the size (for speed only, to
185
+ # avoid too many operations).
186
+ scale_grads[step % size_update_period] = (p * grad).sum(
187
+ dim=list(range(1, p.ndim)), keepdim=True
188
+ )
189
+
190
+ # periodically recompute the value of param_rms.
191
+ if step % size_update_period == size_update_period - 1:
192
+ param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
193
+
194
+ param_min_rms = group["param_min_rms"]
195
+
196
+ # scale the step size by param_rms. This is the most important "scaling" part of
197
+ # ScaledAdam
198
+ delta *= param_rms.clamp(min=param_min_rms)
199
+
200
+ if step % size_update_period == size_update_period - 1 and step > 0:
201
+ # This block updates the size of parameter by adding a step ("delta") value in
202
+ # the direction of either shrinking or growing it.
203
+ beta2 = group["betas"][1]
204
+ size_lr = group["lr"] * group["scalar_lr_scale"]
205
+ param_max_rms = group["param_max_rms"]
206
+ eps = group["eps"]
207
+ # correct beta2 for the size update period: we will have
208
+ # faster decay at this level.
209
+ beta2_corr = beta2**size_update_period
210
+ scale_exp_avg_sq.mul_(beta2_corr).add_(
211
+ (scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
212
+ alpha=1 - beta2_corr,
213
+ ) # shape is (batch_size, 1, 1, ...)
214
+
215
+ # The 1st time we reach here is when size_step == 1.
216
+ size_step = (step + 1) // size_update_period
217
+ bias_correction2 = 1 - beta2_corr**size_step
218
+
219
+ denom = scale_exp_avg_sq.sqrt() + eps
220
+
221
+ scale_step = (
222
+ -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
223
+ )
224
+
225
+ is_too_small = param_rms < param_min_rms
226
+
227
+ # when the param gets too small, just don't shrink it any further.
228
+ scale_step.masked_fill_(is_too_small, 0.0)
229
+
230
+ # The following may help prevent instability: don't allow the scale step to be
231
+ # too large in either direction.
232
+ scale_step.clamp_(min=-0.1, max=0.1)
233
+
234
+ # and ensure the parameter rms after update never exceeds param_max_rms.
235
+ # We have to look at the trained model for parameters at or around the
236
+ # param_max_rms, because sometimes they can indicate a problem with the
237
+ # topology or settings.
238
+ scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
239
+
240
+ delta.add_(p * scale_step)
241
+
242
+ return delta
243
+
244
+
245
+ def momentum_step(group, p, state, grad):
246
+ delta = scaling_step(group, p, state, grad)
247
+ beta1 = group["betas"][0]
248
+ try:
249
+ stored_delta = state["delta"]
250
+ except KeyError:
251
+ stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float)
252
+ state["delta"] = stored_delta
253
+ stored_delta.mul_(beta1)
254
+ stored_delta.add_(delta, alpha=(1 - beta1))
255
+ # we don't bother doing the "bias correction" part of Adam for beta1 because this is
256
+ # just an edge effect that affects the first 10 or so batches; and the effect of not
257
+ # doing it is just to do a slower update for the first few batches, which will help
258
+ # stability.
259
+ return stored_delta
260
+
261
+
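Summarizing the three helpers above, the per-parameter update that ScaledAdam (defined next) applies in step() is, in effect (annotation, not part of the commit):

# Effective ScaledAdam update, composing the helpers above (sketch only):
#
#   v      <- beta2 * v + (1 - beta2) * grad**2                  # basic_step
#   delta0 <- -lr * grad / (sqrt(v) + eps)    (bias-corrected early in training)
#
#   delta1 <- delta0 * clamp(rms(p), min=param_min_rms)          # scaling_step
#             plus, every size_update_period steps, an extra term that grows or
#             shrinks the overall scale of p (kept within param_max_rms)
#
#   m      <- beta1 * m + (1 - beta1) * delta1                   # momentum_step
#   p      <- p + m                                              # applied in step()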
262
+ class ScaledAdam(BatchedOptimizer):
263
+ """
264
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
265
+ proportional to the norm of that parameter; and also learn the scale of the
266
+ parameter, in log space, subject to upper and lower limits (as if we had factored
267
+ each parameter as param = underlying_param * log_scale.exp())
268
+
269
+
270
+ Args:
271
+ params: The parameters or param_groups to optimize (like other Optimizer
272
+ subclasses) Unlike common optimizers, which accept
273
+ model.parameters() or groups of parameters(), this optimizer
274
+ could accept model.named_parameters() or groups of
275
+ named_parameters(). See comments of function
276
+ _get_names_of_parameters for its 4 possible cases.
277
+ lr: The learning rate. We will typically use a learning rate schedule
278
+ that starts at 0.03 and decreases over time, i.e. much higher
279
+ than other common optimizers.
280
+ clipping_scale: (e.g. 2.0)
281
+ A scale for gradient-clipping: if specified, the normalized gradients
282
+ over the whole model will be clipped to have 2-norm equal to
283
+ `clipping_scale` times the median 2-norm over the most recent period
284
+ of `clipping_update_period` minibatches. By "normalized gradients",
285
+ we mean after multiplying by the rms parameter value for this tensor
286
+ [for non-scalars]; this is appropriate because our update is scaled
287
+ by this quantity.
288
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving
289
+ sum-sq grad. Must satisfy 0 < beta1 <= beta2 < 1.
290
+ scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
291
+ scale of each parameter tensor and scalar parameters of the model.
292
+ If each parameter were decomposed as p * p_scale.exp(),
293
+ where (p**2).mean().sqrt() == 1.0, scalar_lr_scale would be the
294
+ scaling factor on the learning rate of p_scale.
295
+ eps: A general-purpose epsilon to prevent division by zero
296
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
297
+ learning the scale on the parameters (we'll constrain the rms of
298
+ each non-scalar parameter tensor to be >= this value)
299
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
300
+ learning the scale on the parameters (we'll constrain the rms of
301
+ each non-scalar parameter tensor to be <= this value)
302
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
303
+ model has any parameters with numel() == 1).
304
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
305
+ of the parameter tensor. This is provided to save a little time
306
+ in the update.
307
+ clipping_update_period: if clipping_scale is specified, this is the period
308
+ """
309
+
310
+ def __init__(
311
+ self,
312
+ params,
313
+ lr=3e-02,
314
+ clipping_scale=None,
315
+ betas=(0.9, 0.98),
316
+ scalar_lr_scale=0.1,
317
+ eps=1.0e-08,
318
+ param_min_rms=1.0e-05,
319
+ param_max_rms=3.0,
320
+ scalar_max=10.0,
321
+ size_update_period=4,
322
+ clipping_update_period=100,
323
+ ):
324
+
325
+ defaults = dict(
326
+ lr=lr,
327
+ clipping_scale=clipping_scale,
328
+ betas=betas,
329
+ scalar_lr_scale=scalar_lr_scale,
330
+ eps=eps,
331
+ param_min_rms=param_min_rms,
332
+ param_max_rms=param_max_rms,
333
+ scalar_max=scalar_max,
334
+ size_update_period=size_update_period,
335
+ clipping_update_period=clipping_update_period,
336
+ )
337
+
338
+ # If params only contains parameters or group of parameters,
339
+ # i.e. when parameter names are not given,
340
+ # this flag will be set to False in function _get_names_of_parameters.
341
+ self.show_dominant_parameters = True
342
+ param_groups, parameters_names = self._get_names_of_parameters(params)
343
+ super(ScaledAdam, self).__init__(param_groups, defaults)
344
+ assert len(self.param_groups) == len(parameters_names)
345
+ self.parameters_names = parameters_names
346
+
347
+ def _get_names_of_parameters(
348
+ self, params_or_named_params
349
+ ) -> Tuple[List[Dict], List[List[str]]]:
350
+ """
351
+ Args:
352
+ params_or_named_params: according to the way ScaledAdam is initialized
353
+ in train.py, this argument could be one of following 4 cases,
354
+ case 1, a generator of parameter, e.g.:
355
+ optimizer = ScaledAdam(model.parameters(), lr=params.base_lr,
356
+ clipping_scale=3.0)
357
+
358
+ case 2, a list of parameter groups with different config, e.g.:
359
+ model_param_groups = [
360
+ {'params': model.encoder.parameters(), 'lr': 0.05},
361
+ {'params': model.decoder.parameters(), 'lr': 0.01},
362
+ {'params': model.joiner.parameters(), 'lr': 0.03},
363
+ ]
364
+ optimizer = ScaledAdam(model_param_groups, lr=params.base_lr,
365
+ clipping_scale=3.0)
366
+
367
+ case 3, a generator of named_parameter, e.g.:
368
+ optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr,
369
+ clipping_scale=3.0)
370
+
371
+ case 4, a list of named_parameter groups with different config, e.g.:
372
+ model_named_param_groups = [
373
+ {'named_params': model.encoder.named_parameters(), 'lr': 0.05},
374
+ {'named_params': model.decoder.named_parameters(), 'lr': 0.01},
375
+ {'named_params': model.joiner.named_parameters(), 'lr': 0.03},
376
+ ]
377
+ optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr,
378
+ clipping_scale=3.0)
379
+
380
+ For case 1 and case 2, input params is used to initialize the underlying
381
+ torch.optimizer.
382
+ For case 3 and case 4, firstly, names and params are extracted from input
383
+ named_params, then, these extracted params are used to initialize the
384
+ underlying torch.optimizer, and these extracted names are mainly used by
385
+ function `_show_gradient_dominating_parameter`
386
+
387
+ Returns:
388
+ Returns a tuple containing 2 elements:
389
+ - `param_groups` with type List[Dict], each Dict element is a parameter
390
+ group. An example of `param_groups` could be:
391
+ [
392
+ {'params': `one iterable of Parameter`, 'lr': 0.05},
393
+ {'params': `another iterable of Parameter`, 'lr': 0.08},
394
+ {'params': `a third iterable of Parameter`, 'lr': 0.1},
395
+ ]
396
+ - `param_groups_names` with type List[List[str]],
397
+ each `List[str]` is for a group['params'] in param_groups,
398
+ and each `str` is the name of a parameter.
399
+ A dummy name "foo" is related to each parameter,
400
+ if input are params without names, i.e. case 1 or case 2.
401
+ """
402
+ # variable naming convention in this function:
403
+ # p is short for param.
404
+ # np is short for named_param.
405
+ # p_or_np is short for param_or_named_param.
406
+ # cur is short for current.
407
+ # group is a dict,
408
+ # e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
409
+ # groups is a List[group]
410
+
411
+ iterable_or_groups = list(params_or_named_params)
412
+ if len(iterable_or_groups) == 0:
413
+ raise ValueError("optimizer got an empty parameter list")
414
+
415
+ # The first value of returned tuple. A list of dicts containing at
416
+ # least 'params' as a key.
417
+ param_groups = []
418
+
419
+ # The second value of returned tuple,
420
+ # a List[List[str]], each sub-List is for a group.
421
+ param_groups_names = []
422
+
423
+ if not isinstance(iterable_or_groups[0], dict):
424
+ # case 1 or case 3,
425
+ # the input is an iterable of parameter or named parameter.
426
+ param_iterable_cur_group = []
427
+ param_names_cur_group = []
428
+ for p_or_np in iterable_or_groups:
429
+ if isinstance(p_or_np, tuple):
430
+ # case 3
431
+ name, param = p_or_np
432
+ else:
433
+ # case 1
434
+ assert isinstance(p_or_np, torch.Tensor)
435
+ param = p_or_np
436
+ # Assign a dummy name as a placeholder
437
+ name = "foo"
438
+ self.show_dominant_parameters = False
439
+ param_iterable_cur_group.append(param)
440
+ param_names_cur_group.append(name)
441
+ param_groups.append({"params": param_iterable_cur_group})
442
+ param_groups_names.append(param_names_cur_group)
443
+ else:
444
+ # case 2 or case 4
445
+ # the input is groups of parameter or named parameter.
446
+ for cur_group in iterable_or_groups:
447
+ if "named_params" in cur_group:
448
+ name_list = [x[0] for x in cur_group["named_params"]]
449
+ p_list = [x[1] for x in cur_group["named_params"]]
450
+ del cur_group["named_params"]
451
+ cur_group["params"] = p_list
452
+ else:
453
+ assert "params" in cur_group
454
+ name_list = ["foo" for _ in cur_group["params"]]
455
+ param_groups.append(cur_group)
456
+ param_groups_names.append(name_list)
457
+
458
+ return param_groups, param_groups_names
459
+
460
+ def __setstate__(self, state):
461
+ super(ScaledAdam, self).__setstate__(state)
462
+
463
+ @torch.no_grad()
464
+ def step(self, closure=None):
465
+ """Performs a single optimization step.
466
+
467
+ Arguments:
468
+ closure (callable, optional): A closure that reevaluates the model
469
+ and returns the loss.
470
+ """
471
+ loss = None
472
+ if closure is not None:
473
+ with torch.enable_grad():
474
+ loss = closure()
475
+
476
+ for group, group_params_names in zip(self.param_groups, self.parameters_names):
477
+
478
+ with self.batched_params(group["params"], group_params_names) as batches:
479
+
480
+ # batches is list of pairs (stacked_param, state). stacked_param is
481
+ # like a regular parameter, and will have a .grad, but the 1st dim
482
+ # corresponds to a stacking dim, it is not a real dim.
483
+
484
+ if (
485
+ len(batches[0][1]) == 0
486
+ ): # if len(first state) == 0: not yet initialized
487
+ clipping_scale = 1
488
+ else:
489
+ clipping_scale = self._get_clipping_scale(group, batches)
490
+
491
+ for p, state, _ in batches:
492
+ # Perform optimization step.
493
+ # grad is not going to be None, we handled that when creating the
494
+ # batches.
495
+ grad = p.grad
496
+ if grad.is_sparse:
497
+ raise RuntimeError(
498
+ "ScaledAdam optimizer does not support sparse gradients"
499
+ )
500
+
501
+ try:
502
+ cur_step = state["step"]
503
+ except KeyError:
504
+ state["step"] = 0
505
+ cur_step = 0
506
+
507
+ grad = (
508
+ p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale)
509
+ )
510
+ p += momentum_step(group, p.detach(), state, grad)
511
+
512
+ if p.numel() == p.shape[0]: # scalar parameter
513
+ scalar_max = group["scalar_max"]
514
+ p.clamp_(min=-scalar_max, max=scalar_max)
515
+
516
+ state["step"] = cur_step + 1
517
+
518
+ return loss
519
+
520
+ def _get_clipping_scale(
521
+ self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
522
+ ) -> float:
523
+ """
524
+ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will
525
+ scale the gradients by this amount before applying the rest of the update.
526
+
527
+ Args:
528
+ group: the parameter group, an item in self.param_groups
529
+ tuples: a list of tuples of (param, state, param_names)
530
+ where param is a batched set of parameters,
531
+ with a .grad (1st dim is batch dim)
532
+ and state is the state-dict where optimization parameters are kept.
533
+ param_names is a List[str] while each str is name for a parameter
534
+ in batched set of parameters "param".
535
+ """
536
+ assert len(tuples) >= 1
537
+ clipping_scale = group["clipping_scale"]
538
+ (first_p, first_state, _) = tuples[0]
539
+ step = first_state["step"]
540
+ if clipping_scale is None or step == 0:
541
+ # no clipping. return early on step == 0 because the other
542
+ # parameters' state won't have been initialized yet.
543
+ return 1.0
544
+ clipping_update_period = group["clipping_update_period"]
545
+ scalar_lr_scale = group["scalar_lr_scale"]
546
+
547
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
548
+ for p, state, param_names in tuples:
549
+ grad = p.grad
550
+ if grad.is_sparse:
551
+ raise RuntimeError(
552
+ "ScaledAdam optimizer does not support sparse gradients"
553
+ )
554
+ if p.numel() == p.shape[0]: # a batch of scalars
555
+ tot_sumsq += (grad**2).sum() * (
556
+ scalar_lr_scale**2
557
+ ) # sum() to change shape [1] to []
558
+ else:
559
+ tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
560
+
561
+ tot_norm = tot_sumsq.sqrt()
562
+ if "model_norms" not in first_state:
563
+ first_state["model_norms"] = torch.zeros(
564
+ clipping_update_period, device=p.device
565
+ )
566
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
567
+
568
+ irregular_estimate_steps = [
569
+ i for i in [10, 20, 40] if i < clipping_update_period
570
+ ]
571
+ if step % clipping_update_period == 0 or step in irregular_estimate_steps:
572
+ # Print some stats.
573
+ # We don't reach here if step == 0 because we would have returned
574
+ # above.
575
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
576
+ if step in irregular_estimate_steps:
577
+ sorted_norms = sorted_norms[-step:]
578
+ num_norms = sorted_norms.numel()
579
+ quartiles = []
580
+ for n in range(0, 5):
581
+ index = min(num_norms - 1, (num_norms // 4) * n)
582
+ quartiles.append(sorted_norms[index].item())
583
+
584
+ median = quartiles[2]
585
+ if median - median != 0:
586
+ raise RuntimeError("Too many grads were not finite")
587
+ threshold = clipping_scale * median
588
+ if step in irregular_estimate_steps:
589
+ # use larger thresholds on first few steps of estimating threshold,
590
+ # as norm may be changing rapidly.
591
+ threshold = threshold * 2.0
592
+ first_state["model_norm_threshold"] = threshold
593
+ percent_clipped = (
594
+ first_state["num_clipped"] * 100.0 / num_norms
595
+ if "num_clipped" in first_state
596
+ else 0.0
597
+ )
598
+ first_state["num_clipped"] = 0
599
+ quartiles = " ".join(["%.3e" % x for x in quartiles])
600
+ logging.warning(
601
+ f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
602
+ f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
603
+ )
604
+
605
+ try:
606
+ model_norm_threshold = first_state["model_norm_threshold"]
607
+ except KeyError:
608
+ return 1.0 # threshold has not yet been set.
609
+
610
+ ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
611
+ if ans != ans: # e.g. ans is nan
612
+ ans = 0.0
613
+ if ans < 1.0:
614
+ first_state["num_clipped"] += 1
615
+ if ans < 0.5:
616
+ logging.warning(
617
+ f"Scaling gradients by {ans}, "
618
+ f"model_norm_threshold={model_norm_threshold}"
619
+ )
620
+ if self.show_dominant_parameters:
621
+ assert p.shape[0] == len(param_names)
622
+ self._show_gradient_dominating_parameter(
623
+ tuples, tot_sumsq, group["scalar_lr_scale"]
624
+ )
625
+ self._show_param_with_unusual_grad(tuples)
626
+
627
+ if ans == 0.0:
628
+ for p, state, param_names in tuples:
629
+ p.grad.zero_() # get rid of infinity()
630
+
631
+ return ans
632
+
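A concrete illustration of the clipping rule implemented above (annotation, not part of the commit): with clipping_scale=2.0, the threshold is twice the median of the recent normalized gradient norms, and gradients on a batch that exceeds it are scaled down proportionally.

# Numeric sketch of _get_clipping_scale with clipping_scale = 2.0:
median_norm = 1.5              # median of the last clipping_update_period norms
threshold = 2.0 * median_norm  # stored as model_norm_threshold -> 3.0
tot_norm = 6.0                 # this batch's normalized gradient norm
ans = min(1.0, threshold / (tot_norm + 1.0e-20))  # -> 0.5: grads scaled by half
print(ans)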
633
+ def _show_param_with_unusual_grad(
634
+ self,
635
+ tuples: List[Tuple[Tensor, dict, List[str]]],
636
+ ):
637
+ """
638
+ Print information about parameter which has the largest ratio of
639
+ grad-on-this-batch divided by normal grad size.
640
+ tuples: a list of tuples of (param, state, param_names)
641
+ where param is a batched set of parameters,
642
+ with a .grad (1st dim is batch dim)
643
+ and state is the state-dict where optimization parameters are kept.
644
+ param_names is a List[str] while each str is name for a parameter
645
+ in batched set of parameters "param".
646
+ """
647
+ # ratios_names is a list of 3-tuples: (grad_ratio, param_name, tensor)
648
+ ratios_names = []
649
+ for p, state, batch_param_names in tuples:
650
+ dims = list(range(1, p.ndim))
651
+
652
+ def mean(x):
653
+ # workaround for bad interface of torch's "mean" for when dims is the
654
+ # empty list.
655
+ if len(dims) > 0:
656
+ return x.mean(dim=dims)
657
+ else:
658
+ return x
659
+
660
+ grad_ratio = (
661
+ (mean(p.grad**2) / state["exp_avg_sq"].mean(dim=dims))
662
+ .sqrt()
663
+ .to("cpu")
664
+ )
665
+
666
+ ratios_names += zip(
667
+ grad_ratio.tolist(), batch_param_names, p.grad.unbind(dim=0)
668
+ )
669
+
670
+ ratios_names = sorted(ratios_names, reverse=True)
671
+ ratios_names = ratios_names[:10]
672
+ ratios_names = [
673
+ (ratio, name, largest_index(tensor))
674
+ for (ratio, name, tensor) in ratios_names
675
+ ]
676
+
677
+ logging.debug(
678
+ f"Parameters with most larger-than-usual grads, with ratios, "
679
+ f"are: {ratios_names}"
680
+ )
681
+
682
+ def _show_gradient_dominating_parameter(
683
+ self,
684
+ tuples: List[Tuple[Tensor, dict, List[str]]],
685
+ tot_sumsq: Tensor,
686
+ scalar_lr_scale: float,
687
+ ):
688
+ """
689
+ Show information of parameter which dominates tot_sumsq.
690
+
691
+ Args:
692
+ tuples: a list of tuples of (param, state, param_names)
693
+ where param is a batched set of parameters,
694
+ with a .grad (1st dim is batch dim)
695
+ and state is the state-dict where optimization parameters are kept.
696
+ param_names is a List[str] while each str is name for a parameter
697
+ in batched set of parameters "param".
698
+ tot_sumsq: sumsq of all parameters. Though it could be calculated
699
+ from tuples, we still pass it to save some time.
700
+ """
701
+ all_sumsq_orig = {}
702
+ for p, state, batch_param_names in tuples:
703
+ # p is a stacked batch parameters.
704
+ batch_grad = p.grad
705
+ if p.numel() == p.shape[0]: # a batch of scalars
706
+ # Dummy values used by following `zip` statement.
707
+ batch_rms_orig = torch.full(
708
+ p.shape, scalar_lr_scale, device=batch_grad.device
709
+ )
710
+ else:
711
+ batch_rms_orig = state["param_rms"]
712
+ batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
713
+ if batch_grad.ndim > 1:
714
+ # need to guard it with if-statement because sum() sums over
715
+ # all dims if dim == ().
716
+ batch_sumsq_orig = batch_sumsq_orig.sum(
717
+ dim=list(range(1, batch_grad.ndim))
718
+ )
719
+ for name, sumsq_orig, rms, grad in zip(
720
+ batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
721
+ ):
722
+
723
+ proportion_orig = sumsq_orig / tot_sumsq
724
+ all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
725
+
726
+ sorted_by_proportion = {
727
+ k: v
728
+ for k, v in sorted(
729
+ all_sumsq_orig.items(),
730
+ key=lambda item: item[1][0],
731
+ reverse=True,
732
+ )
733
+ }
734
+ dominant_param_name = next(iter(sorted_by_proportion))
735
+ (
736
+ dominant_proportion,
737
+ dominant_sumsq,
738
+ dominant_rms,
739
+ dominant_grad,
740
+ ) = sorted_by_proportion[dominant_param_name]
741
+ logging.debug(
742
+ f"Parameter dominating tot_sumsq {dominant_param_name}"
743
+ f" with proportion {dominant_proportion:.2f},"
744
+ f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
745
+ f"={dominant_sumsq:.3e},"
746
+ f" grad_sumsq={(dominant_grad**2).sum():.3e},"
747
+ f" orig_rms_sq={(dominant_rms**2).item():.3e}"
748
+ )
749
+
750
+
751
+ def largest_index(x: Tensor):
752
+ x = x.contiguous()
753
+ argmax = x.abs().argmax().item()
754
+ return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)]
755
+
756
+
757
+ def _test_scaled_adam(hidden_dim: int):
758
+ import timeit
759
+
760
+ from zipvoice.models.modules.scaling import ScaledLinear
761
+ from zipvoice.utils.lr_scheduler import Eden
762
+
763
+ E = 100
764
+ B = 4
765
+ T = 2
766
+ logging.info("in test_eve_cain")
767
+ # device = torch.device('cuda')
768
+ device = torch.device("cpu")
769
+ dtype = torch.float32
770
+
771
+ fix_random_seed(42)
772
+ # these input_magnitudes and output_magnitudes are to test that
773
+ # Abel is working as we expect and is able to adjust scales of
774
+ # different dims differently.
775
+ input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
776
+ output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
777
+
778
+ fix_random_seed(42)
779
+ Linear = ScaledLinear
780
+
781
+ m = torch.nn.Sequential(
782
+ Linear(E, hidden_dim),
783
+ torch.nn.PReLU(),
784
+ Linear(hidden_dim, hidden_dim),
785
+ torch.nn.PReLU(),
786
+ Linear(hidden_dim, E),
787
+ ).to(device)
788
+
789
+ train_pairs = [
790
+ (
791
+ 100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes,
792
+ torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes,
793
+ )
794
+ for _ in range(20)
795
+ ]
796
+ optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0)
797
+ scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
798
+
799
+ start = timeit.default_timer()
800
+ avg_loss = 0.0
801
+ for epoch in range(180):
802
+ scheduler.step_epoch()
803
+ # if epoch == 100 and iter in [2,3]:
804
+ # optim.reset_speedup() # check it doesn't crash.
805
+
806
+ # if epoch == 130:
807
+ # opts = diagnostics.TensorDiagnosticOptions(
808
+ # 512
809
+ # ) # allow 4 megabytes per sub-module
810
+ # diagnostic = diagnostics.attach_diagnostics(m, opts)
811
+
812
+ for n, (x, y) in enumerate(train_pairs):
813
+ y_out = m(x)
814
+ loss = ((y_out - y) ** 2).mean() * 100.0
815
+ if epoch == 0 and n == 0:
816
+ avg_loss = loss.item()
817
+ else:
818
+ avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
819
+ if n == 0 and epoch % 5 == 0:
820
+ # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
821
+ # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
822
+ # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
823
+ # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
824
+ # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
825
+ # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
826
+ # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
827
+ # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
828
+ lr = scheduler.get_last_lr()[0]
829
+ logging.info(
830
+ f"Iter {iter}, epoch {epoch}, batch {n}, "
831
+ f"avg_loss {avg_loss:.4g}, lr={lr:.4e}"
832
+ ) # , norms={norm1,norm1b,norm2,norm2b}")
833
+ # scales={scale1,scale1b,scale2,scale2b}
834
+ loss.log().backward()
835
+ optim.step()
836
+ optim.zero_grad()
837
+ scheduler.step_batch()
838
+
839
+ # diagnostic.print_diagnostics()
840
+
841
+ stop = timeit.default_timer()
842
+ logging.info(f"Iter={iter}, Time taken: {stop - start}")
843
+
844
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
845
+ # logging.info("state dict = ", scheduler.state_dict())
846
+ # logging.info("optim state_dict = ", optim.state_dict())
847
+ logging.info(f"input_magnitudes = {input_magnitudes}")
848
+ logging.info(f"output_magnitudes = {output_magnitudes}")
849
+
850
+
851
+ if __name__ == "__main__":
852
+ torch.set_num_threads(1)
853
+ torch.set_num_interop_threads(1)
854
+ logging.getLogger().setLevel(logging.INFO)
855
+ import subprocess
856
+
857
+ s = subprocess.check_output(
858
+ "git status -uno .; git log -1; git diff HEAD .", shell=True
859
+ )
860
+ logging.info(s)
861
+ import sys
862
+
863
+ if len(sys.argv) > 1:
864
+ hidden_dim = int(sys.argv[1])
865
+ else:
866
+ hidden_dim = 200
867
+
868
+ _test_scaled_adam(hidden_dim)
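Note: the `_show_gradient_dominating_parameter` diagnostic above ranks each parameter by `(grad * param_rms) ** 2`, summed over its non-batch dims, as a fraction of `tot_sumsq`. A minimal standalone sketch of that ranking (hypothetical helper and parameter names, independent of ScaledAdam's batched state):

import torch

def dominant_param(named_grads_and_rms):
    # Each entry is (name, grad, rms); rank parameters by sum of (grad * rms) ** 2.
    sumsqs = {name: ((g * r) ** 2).sum().item() for name, g, r in named_grads_and_rms}
    tot = sum(sumsqs.values())
    name = max(sumsqs, key=sumsqs.get)
    return name, sumsqs[name] / tot

# Example: the parameter with the larger grad * rms dominates the total.
g1, r1 = torch.randn(10), torch.full((10,), 0.1)
g2, r2 = 5.0 * torch.randn(10), torch.full((10,), 0.1)
print(dominant_param([("enc.weight", g1, r1), ("dec.weight", g2, r2)]))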
zipvoice/utils/scaling_converter.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
2
+ # Zengwei Yao)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ This file replaces various modules in a model with export-friendly equivalents.
20
+ Specifically, Balancer, Dropout3 and Whiten are replaced with identity operators;
21
+ for ONNX export, SwooshR and SwooshL are replaced with their ONNX variants, and
22
+ CompactRelPositionalEncoding is wrapped with torch.jit.script().
23
+ """
24
+
25
+ import copy
26
+ from typing import List
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+
31
+ from zipvoice.models.modules.scaling import (
32
+ Balancer,
33
+ Dropout3,
34
+ SwooshL,
35
+ SwooshLOnnx,
36
+ SwooshR,
37
+ SwooshROnnx,
38
+ Whiten,
39
+ )
40
+ from zipvoice.models.modules.zipformer import CompactRelPositionalEncoding
41
+
42
+
43
+ # Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
44
+ # get_submodule was added to nn.Module at v1.9.0
45
+ def get_submodule(model, target):
46
+ if target == "":
47
+ return model
48
+ atoms: List[str] = target.split(".")
49
+ mod: torch.nn.Module = model
50
+ for item in atoms:
51
+ if not hasattr(mod, item):
52
+ raise AttributeError(
53
+ mod._get_name() + " has no " "attribute `" + item + "`"
54
+ )
55
+ mod = getattr(mod, item)
56
+ if not isinstance(mod, torch.nn.Module):
57
+ raise AttributeError("`" + item + "` is not " "an nn.Module")
58
+ return mod
59
+
60
+
61
+ def convert_scaled_to_non_scaled(
62
+ model: nn.Module,
63
+ inplace: bool = False,
64
+ is_pnnx: bool = False,
65
+ is_onnx: bool = False,
66
+ ):
67
+ """
68
+ Args:
69
+ model:
70
+ The model to be converted.
71
+ inplace:
72
+ If True, the input model is modified inplace.
73
+ If False, the input model is copied and we modify the copied version.
74
+ is_pnnx:
75
+ True if we are going to export the model for PNNX.
76
+ is_onnx:
77
+ True if we are going to export the model for ONNX.
78
+ Return:
79
+ Return a model without scaled layers.
80
+ """
81
+ if not inplace:
82
+ model = copy.deepcopy(model)
83
+
84
+ d = {}
85
+ for name, m in model.named_modules():
86
+ if isinstance(m, (Balancer, Dropout3, Whiten)):
87
+ d[name] = nn.Identity()
88
+ elif is_onnx and isinstance(m, SwooshR):
89
+ d[name] = SwooshROnnx()
90
+ elif is_onnx and isinstance(m, SwooshL):
91
+ d[name] = SwooshLOnnx()
92
+ elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
93
+ # We want to recreate the positional encoding vector when
94
+ # the input changes, so we have to use torch.jit.script()
95
+ # to replace torch.jit.trace()
96
+ d[name] = torch.jit.script(m)
97
+
98
+ for k, v in d.items():
99
+ if "." in k:
100
+ parent, child = k.rsplit(".", maxsplit=1)
101
+ setattr(get_submodule(model, parent), child, v)
102
+ else:
103
+ setattr(model, k, v)
104
+
105
+ return model
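Note: a typical use of `convert_scaled_to_non_scaled` above is right before export. A minimal sketch under an assumed setup (the placeholder model here is a plain nn.Sequential, not the actual ZipVoice model, which contains the Balancer/Whiten/Swoosh modules being converted):

import torch.nn as nn
from zipvoice.utils.scaling_converter import convert_scaled_to_non_scaled

# Placeholder model; in practice this is the trained ZipVoice/Zipformer model.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4)).eval()

# inplace=False leaves `model` untouched and returns a converted deep copy;
# with is_onnx=True, Swoosh activations are swapped for their ONNX variants and
# Balancer/Dropout3/Whiten become nn.Identity() (a no-op on this toy model).
export_model = convert_scaled_to_non_scaled(model, inplace=False, is_onnx=True)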