diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..2d87eedfbcc16d6fd1c4af49e062079984fbe519
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Songting
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9859ef8c84627abc0c4b8f19a3d7b96163c8af01
--- /dev/null
+++ b/README.md
@@ -0,0 +1,14 @@
+---
+title: VALL E X
+emoji: 🎙
+colorFrom: green
+colorTo: purple
+sdk: gradio
+sdk_version: 3.39.0
+app_file: app.py
+pinned: false
+license: mit
+duplicated_from: Plachta/VALL-E-X
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ee023a9375d7142795a0f3d8037dacca54efdbe
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1 @@
+from . import data, models, modules, utils
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..67e1913875ccf96c4e329411fab8614ab67d58bd
--- /dev/null
+++ b/app.py
@@ -0,0 +1,582 @@
+import logging
+import os
+import pathlib
+import time
+import tempfile
+import platform
+import gc
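+# Remap pathlib classes so checkpoints pickled on a different OS (e.g. WindowsPath objects on Linux) can still be unpickled.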
+if platform.system().lower() == 'windows':
+ temp = pathlib.PosixPath
+ pathlib.PosixPath = pathlib.WindowsPath
+elif platform.system().lower() == 'linux':
+ temp = pathlib.WindowsPath
+ pathlib.WindowsPath = pathlib.PosixPath
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
+import langid
+langid.set_languages(['en', 'zh', 'ja'])
+
+import torch
+import torchaudio
+
+import numpy as np
+
+from data.tokenizer import (
+ AudioTokenizer,
+ tokenize_audio,
+)
+from data.collation import get_text_token_collater
+from models.vallex import VALLE
+from utils.g2p import PhonemeBpeTokenizer
+from descriptions import *
+from macros import *
+from examples import *
+
+import gradio as gr
+from vocos import Vocos
+from transformers import WhisperProcessor, WhisperForConditionalGeneration
+
+
+
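+# Disable TorchScript JIT profiling and graph optimization; they are not needed for eager-mode inference.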
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_set_profiling_mode(False)
+torch._C._set_graph_executor_optimize(False)
+
+text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
+text_collater = get_text_token_collater()
+
+device = torch.device("cpu")
+if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+# VALL-E-X model
+model = VALLE(
+ N_DIM,
+ NUM_HEAD,
+ NUM_LAYERS,
+ norm_first=True,
+ add_prenet=False,
+ prefix_mode=PREFIX_MODE,
+ share_embedding=True,
+ nar_scale_factor=1.0,
+ prepend_bos=True,
+ num_quantizers=NUM_QUANTIZERS,
+ ).to(device)
+checkpoint = torch.load("./epoch-10.pt", map_location='cpu')
+missing_keys, unexpected_keys = model.load_state_dict(
+ checkpoint["model"], strict=True
+)
+del checkpoint
+assert not missing_keys
+model.eval()
+
+# Encodec model
+audio_tokenizer = AudioTokenizer(device)
+
+# Vocos decoder
+vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)
+
+# ASR
+whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
+whisper = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium").to(device)
+whisper.config.forced_decoder_ids = None
+
+# Voice Presets
+preset_list = next(os.walk("./presets/"))[2]
+preset_list = [preset[:-4] for preset in preset_list if preset.endswith(".npz")]
+
+def clear_prompts():
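+    # Remove temporary .npz prompt files older than one minute from the system temp directory.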
+ try:
+ path = tempfile.gettempdir()
+ for eachfile in os.listdir(path):
+ filename = os.path.join(path, eachfile)
+ if os.path.isfile(filename) and filename.endswith(".npz"):
+ lastmodifytime = os.stat(filename).st_mtime
+ endfiletime = time.time() - 60
+ if endfiletime > lastmodifytime:
+ os.remove(filename)
+ del path, filename, lastmodifytime, endfiletime
+ gc.collect()
+    except Exception:
+ return
+
+def transcribe_one(wav, sr):
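+    # Resample to 16 kHz if needed, then run Whisper to detect the language and transcribe the prompt audio.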
+ if sr != 16000:
+ wav4trans = torchaudio.transforms.Resample(sr, 16000)(wav)
+ else:
+ wav4trans = wav
+
+ input_features = whisper_processor(wav4trans.squeeze(0), sampling_rate=16000, return_tensors="pt").input_features
+
+ # generate token ids
+ predicted_ids = whisper.generate(input_features.to(device))
+ lang = whisper_processor.batch_decode(predicted_ids[:, 1])[0].strip("<|>")
+ # decode token ids to text
+ text_pr = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
+ # print the recognized text
+ print(text_pr)
+
+ if text_pr.strip(" ")[-1] not in "?!.,。,?!。、":
+ text_pr += "."
+
+ # delete all variables
+ del wav4trans, input_features, predicted_ids
+ gc.collect()
+ return lang, text_pr
+
+def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
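+    # Build a reusable voice prompt: transcribe the audio (or use the provided transcript),
+    # tokenize it with EnCodec and the text tokenizer, and save the result to a temporary .npz file.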
+ clear_prompts()
+ audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
+ sr, wav_pr = audio_prompt
+ if len(wav_pr) / sr > 15:
+ return "Rejected, Audio too long (should be less than 15 seconds)", None
+ if not isinstance(wav_pr, torch.FloatTensor):
+ wav_pr = torch.FloatTensor(wav_pr)
+ if wav_pr.abs().max() > 1:
+ wav_pr /= wav_pr.abs().max()
+ if wav_pr.size(-1) == 2:
+ wav_pr = wav_pr[:, 0]
+ if wav_pr.ndim == 1:
+ wav_pr = wav_pr.unsqueeze(0)
+    assert wav_pr.ndim == 2 and wav_pr.size(0) == 1
+
+ if transcript_content == "":
+ lang_pr, text_pr = transcribe_one(wav_pr, sr)
+ lang_token = lang2token[lang_pr]
+ text_pr = lang_token + text_pr + lang_token
+ else:
+ lang_pr = langid.classify(str(transcript_content))[0]
+ lang_token = lang2token[lang_pr]
+ transcript_content = transcript_content.replace("\n", "")
+ text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
+ # tokenize audio
+ encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
+ audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
+
+ # tokenize text
+ phonemes, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
+ text_tokens, enroll_x_lens = text_collater(
+ [
+ phonemes
+ ]
+ )
+
+ message = f"Detected language: {lang_pr}\n Detected text {text_pr}\n"
+ if lang_pr not in ['ja', 'zh', 'en']:
+ return f"Prompt can only made with one of model-supported languages, got {lang_pr} instead", None
+
+ # save as npz file
+ np.savez(os.path.join(tempfile.gettempdir(), f"{name}.npz"),
+ audio_tokens=audio_tokens, text_tokens=text_tokens, lang_code=lang2code[lang_pr])
+
+ # delete all variables
+ del audio_tokens, text_tokens, phonemes, lang_pr, text_pr, wav_pr, sr, uploaded_audio, recorded_audio
+ gc.collect()
+ return message, os.path.join(tempfile.gettempdir(), f"{name}.npz")
+
+
+@torch.no_grad()
+def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
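+    # Synthesize the input text in the voice of the given audio prompt; if no prompt is given, fall back to unconditioned TTS with empty prompts.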
+ if len(text) > 150:
+ return "Rejected, Text too long (should be less than 150 characters)", None
+ if audio_prompt is None and record_audio_prompt is None:
+ audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
+ text_prompts = torch.zeros([1, 0]).type(torch.int32)
+ lang_pr = 'en'
+ text_pr = ""
+ enroll_x_lens = 0
+ wav_pr, sr = None, None
+ else:
+ audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
+ sr, wav_pr = audio_prompt
+ if len(wav_pr) / sr > 15:
+ return "Rejected, Audio too long (should be less than 15 seconds)", None
+ if not isinstance(wav_pr, torch.FloatTensor):
+ wav_pr = torch.FloatTensor(wav_pr)
+ if wav_pr.abs().max() > 1:
+ wav_pr /= wav_pr.abs().max()
+ if wav_pr.size(-1) == 2:
+ wav_pr = wav_pr[:, 0]
+ if wav_pr.ndim == 1:
+ wav_pr = wav_pr.unsqueeze(0)
+        assert wav_pr.ndim == 2 and wav_pr.size(0) == 1
+
+ if transcript_content == "":
+ lang_pr, text_pr = transcribe_one(wav_pr, sr)
+ lang_token = lang2token[lang_pr]
+ text_pr = lang_token + text_pr + lang_token
+ else:
+ lang_pr = langid.classify(str(transcript_content))[0]
+ text_pr = transcript_content.replace("\n", "")
+ if lang_pr not in ['ja', 'zh', 'en']:
+ return f"Reference audio must be a speech of one of model-supported languages, got {lang_pr} instead", None
+ lang_token = lang2token[lang_pr]
+ text_pr = lang_token + text_pr + lang_token
+
+ # tokenize audio
+ encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
+ audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
+
+ enroll_x_lens = None
+ if text_pr:
+ text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
+ text_prompts, enroll_x_lens = text_collater(
+ [
+ text_prompts
+ ]
+ )
+
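+    # Wrap the text to synthesize in language tokens; 'auto-detect' picks the language with langid.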
+ if language == 'auto-detect':
+ lang_token = lang2token[langid.classify(text)[0]]
+ else:
+ lang_token = langdropdown2token[language]
+ lang = token2lang[lang_token]
+ text = text.replace("\n", "")
+ text = lang_token + text + lang_token
+
+ # tokenize text
+ logging.info(f"synthesize text: {text}")
+ phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
+ text_tokens, text_tokens_lens = text_collater(
+ [
+ phone_tokens
+ ]
+ )
+
+
+ text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
+ text_tokens_lens += enroll_x_lens
+ lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
+ encoded_frames = model.inference(
+ text_tokens.to(device),
+ text_tokens_lens.to(device),
+ audio_prompts,
+ enroll_x_lens=enroll_x_lens,
+ top_k=-100,
+ temperature=1,
+ prompt_language=lang_pr,
+ text_language=langs if accent == "no-accent" else lang,
+ )
+ # Decode with Vocos
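+    # bandwidth_id=2 corresponds to EnCodec's 6 kbps setting (8 codebooks), matching NUM_QUANTIZERS.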
+ frames = encoded_frames.permute(2,0,1)
+ features = vocos.codes_to_features(frames)
+ samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
+
+ message = f"text prompt: {text_pr}\nsythesized text: {text}"
+ # delete all variables
+ del audio_prompts, text_tokens, text_prompts, phone_tokens, encoded_frames, wav_pr, sr, audio_prompt, record_audio_prompt, transcript_content
+ gc.collect()
+ return message, (24000, samples.squeeze(0).cpu().numpy())
+
+@torch.no_grad()
+def infer_from_prompt(text, language, accent, preset_prompt, prompt_file):
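+    # Same as infer_from_audio, but the voice prompt comes from a precomputed .npz file or a selected preset.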
+ if len(text) > 150:
+ return "Rejected, Text too long (should be less than 150 characters)", None
+ clear_prompts()
+ # text to synthesize
+ if language == 'auto-detect':
+ lang_token = lang2token[langid.classify(text)[0]]
+ else:
+ lang_token = langdropdown2token[language]
+ lang = token2lang[lang_token]
+ text = text.replace("\n", "")
+ text = lang_token + text + lang_token
+
+ # load prompt
+ if prompt_file is not None:
+ prompt_data = np.load(prompt_file.name)
+ else:
+ prompt_data = np.load(os.path.join("./presets/", f"{preset_prompt}.npz"))
+ audio_prompts = prompt_data['audio_tokens']
+ text_prompts = prompt_data['text_tokens']
+ lang_pr = prompt_data['lang_code']
+ lang_pr = code2lang[int(lang_pr)]
+
+ # numpy to tensor
+ audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
+ text_prompts = torch.tensor(text_prompts).type(torch.int32)
+
+ enroll_x_lens = text_prompts.shape[-1]
+ logging.info(f"synthesize text: {text}")
+ phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
+ text_tokens, text_tokens_lens = text_collater(
+ [
+ phone_tokens
+ ]
+ )
+ text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
+ text_tokens_lens += enroll_x_lens
+ # accent control
+ lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
+ encoded_frames = model.inference(
+ text_tokens.to(device),
+ text_tokens_lens.to(device),
+ audio_prompts,
+ enroll_x_lens=enroll_x_lens,
+ top_k=-100,
+ temperature=1,
+ prompt_language=lang_pr,
+ text_language=langs if accent == "no-accent" else lang,
+ )
+ # Decode with Vocos
+ frames = encoded_frames.permute(2,0,1)
+ features = vocos.codes_to_features(frames)
+ samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
+
+ message = f"sythesized text: {text}"
+
+ # delete all variables
+ del audio_prompts, text_tokens, text_prompts, phone_tokens, encoded_frames, prompt_file, preset_prompt
+ gc.collect()
+ return message, (24000, samples.squeeze(0).cpu().numpy())
+
+
+from utils.sentence_cutter import split_text_into_sentences
+@torch.no_grad()
+def infer_long_text(text, preset_prompt, prompt=None, language='auto', accent='no-accent'):
+ """
+ For long audio generation, two modes are available.
+    fixed-prompt: This mode keeps using the same prompt the user provided and generates audio sentence by sentence.
+    sliding-window: This mode uses the last generated sentence as the prompt for the next one, but the speaker identity may drift over long passages.
+ """
+ if len(text) > 1000:
+ return "Rejected, Text too long (should be less than 1000 characters)", None
+ mode = 'fixed-prompt'
+ if (prompt is None or prompt == "") and preset_prompt == "":
+ mode = 'sliding-window' # If no prompt is given, use sliding-window mode
+ sentences = split_text_into_sentences(text)
+ # detect language
+ if language == "auto-detect":
+ language = langid.classify(text)[0]
+ else:
+ language = token2lang[langdropdown2token[language]]
+
+ # if initial prompt is given, encode it
+ if prompt is not None and prompt != "":
+ # load prompt
+ prompt_data = np.load(prompt.name)
+ audio_prompts = prompt_data['audio_tokens']
+ text_prompts = prompt_data['text_tokens']
+ lang_pr = prompt_data['lang_code']
+ lang_pr = code2lang[int(lang_pr)]
+
+ # numpy to tensor
+ audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
+ text_prompts = torch.tensor(text_prompts).type(torch.int32)
+ elif preset_prompt is not None and preset_prompt != "":
+ prompt_data = np.load(os.path.join("./presets/", f"{preset_prompt}.npz"))
+ audio_prompts = prompt_data['audio_tokens']
+ text_prompts = prompt_data['text_tokens']
+ lang_pr = prompt_data['lang_code']
+ lang_pr = code2lang[int(lang_pr)]
+
+ # numpy to tensor
+ audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
+ text_prompts = torch.tensor(text_prompts).type(torch.int32)
+ else:
+ audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
+ text_prompts = torch.zeros([1, 0]).type(torch.int32)
+ lang_pr = language if language != 'mix' else 'en'
+ if mode == 'fixed-prompt':
+ complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
+ for text in sentences:
+ text = text.replace("\n", "").strip(" ")
+ if text == "":
+ continue
+ lang_token = lang2token[language]
+ lang = token2lang[lang_token]
+ text = lang_token + text + lang_token
+
+ enroll_x_lens = text_prompts.shape[-1]
+ logging.info(f"synthesize text: {text}")
+ phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
+ text_tokens, text_tokens_lens = text_collater(
+ [
+ phone_tokens
+ ]
+ )
+ text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
+ text_tokens_lens += enroll_x_lens
+ # accent control
+ lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
+ encoded_frames = model.inference(
+ text_tokens.to(device),
+ text_tokens_lens.to(device),
+ audio_prompts,
+ enroll_x_lens=enroll_x_lens,
+ top_k=-100,
+ temperature=1,
+ prompt_language=lang_pr,
+ text_language=langs if accent == "no-accent" else lang,
+ )
+ complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
+ # Decode with Vocos
+ frames = complete_tokens.permute(1, 0, 2)
+ features = vocos.codes_to_features(frames)
+ samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
+
+ message = f"Cut into {len(sentences)} sentences"
+ return message, (24000, samples.squeeze(0).cpu().numpy())
+ elif mode == "sliding-window":
+ complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
+ original_audio_prompts = audio_prompts
+ original_text_prompts = text_prompts
+ for text in sentences:
+ text = text.replace("\n", "").strip(" ")
+ if text == "":
+ continue
+ lang_token = lang2token[language]
+ lang = token2lang[lang_token]
+ text = lang_token + text + lang_token
+
+ enroll_x_lens = text_prompts.shape[-1]
+ logging.info(f"synthesize text: {text}")
+ phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
+ text_tokens, text_tokens_lens = text_collater(
+ [
+ phone_tokens
+ ]
+ )
+ text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
+ text_tokens_lens += enroll_x_lens
+ # accent control
+ lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
+ encoded_frames = model.inference(
+ text_tokens.to(device),
+ text_tokens_lens.to(device),
+ audio_prompts,
+ enroll_x_lens=enroll_x_lens,
+ top_k=-100,
+ temperature=1,
+ prompt_language=lang_pr,
+ text_language=langs if accent == "no-accent" else lang,
+ )
+ complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
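+            # The condition below is always true, so the last generated sentence always becomes the prompt for the next one.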
+ if torch.rand(1) < 1.0:
+ audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]
+ text_prompts = text_tokens[:, enroll_x_lens:]
+ else:
+ audio_prompts = original_audio_prompts
+ text_prompts = original_text_prompts
+ # Decode with Vocos
+ frames = complete_tokens.permute(1, 0, 2)
+ features = vocos.codes_to_features(frames)
+ samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))
+
+ message = f"Cut into {len(sentences)} sentences"
+
+ return message, (24000, samples.squeeze(0).cpu().numpy())
+ else:
+ raise ValueError(f"No such mode {mode}")
+
+app = gr.Blocks()
+with app:
+ gr.Markdown(top_md)
+ with gr.Tab("Infer from audio"):
+ gr.Markdown(infer_from_audio_md)
+ with gr.Row():
+ with gr.Column():
+
+ textbox = gr.TextArea(label="Text",
+ placeholder="Type your sentence here",
+ value="Welcome back, Master. What can I do for you today?", elem_id=f"tts-input")
+ language_dropdown = gr.Dropdown(choices=['auto-detect', 'English', '中文', '日本語'], value='auto-detect', label='language')
+ accent_dropdown = gr.Dropdown(choices=['no-accent', 'English', '中文', '日本語'], value='no-accent', label='accent')
+ textbox_transcript = gr.TextArea(label="Transcript",
+ placeholder="Write transcript here. (leave empty to use whisper)",
+ value="", elem_id=f"prompt-name")
+ upload_audio_prompt = gr.Audio(label='uploaded audio prompt', source='upload', interactive=True)
+ record_audio_prompt = gr.Audio(label='recorded audio prompt', source='microphone', interactive=True)
+ with gr.Column():
+ text_output = gr.Textbox(label="Message")
+ audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio")
+ btn = gr.Button("Generate!")
+ btn.click(infer_from_audio,
+ inputs=[textbox, language_dropdown, accent_dropdown, upload_audio_prompt, record_audio_prompt, textbox_transcript],
+ outputs=[text_output, audio_output])
+ textbox_mp = gr.TextArea(label="Prompt name",
+ placeholder="Name your prompt here",
+ value="prompt_1", elem_id=f"prompt-name")
+ btn_mp = gr.Button("Make prompt!")
+ prompt_output = gr.File(interactive=False)
+ btn_mp.click(make_npz_prompt,
+ inputs=[textbox_mp, upload_audio_prompt, record_audio_prompt, textbox_transcript],
+ outputs=[text_output, prompt_output])
+ gr.Examples(examples=infer_from_audio_examples,
+ inputs=[textbox, language_dropdown, accent_dropdown, upload_audio_prompt, record_audio_prompt, textbox_transcript],
+ outputs=[text_output, audio_output],
+ fn=infer_from_audio,
+ cache_examples=False,)
+ with gr.Tab("Make prompt"):
+ gr.Markdown(make_prompt_md)
+ with gr.Row():
+ with gr.Column():
+ textbox2 = gr.TextArea(label="Prompt name",
+ placeholder="Name your prompt here",
+ value="prompt_1", elem_id=f"prompt-name")
+ # 添加选择语言和输入台本的地方
+ textbox_transcript2 = gr.TextArea(label="Transcript",
+ placeholder="Write transcript here. (leave empty to use whisper)",
+ value="", elem_id=f"prompt-name")
+ upload_audio_prompt_2 = gr.Audio(label='uploaded audio prompt', source='upload', interactive=True)
+ record_audio_prompt_2 = gr.Audio(label='recorded audio prompt', source='microphone', interactive=True)
+ with gr.Column():
+ text_output_2 = gr.Textbox(label="Message")
+ prompt_output_2 = gr.File(interactive=False)
+ btn_2 = gr.Button("Make!")
+ btn_2.click(make_npz_prompt,
+ inputs=[textbox2, upload_audio_prompt_2, record_audio_prompt_2, textbox_transcript2],
+ outputs=[text_output_2, prompt_output_2])
+ gr.Examples(examples=make_npz_prompt_examples,
+ inputs=[textbox2, upload_audio_prompt_2, record_audio_prompt_2, textbox_transcript2],
+ outputs=[text_output_2, prompt_output_2],
+ fn=make_npz_prompt,
+ cache_examples=False,)
+ with gr.Tab("Infer from prompt"):
+ gr.Markdown(infer_from_prompt_md)
+ with gr.Row():
+ with gr.Column():
+ textbox_3 = gr.TextArea(label="Text",
+ placeholder="Type your sentence here",
+ value="Welcome back, Master. What can I do for you today?", elem_id=f"tts-input")
+ language_dropdown_3 = gr.Dropdown(choices=['auto-detect', 'English', '中文', '日本語', 'Mix'], value='auto-detect',
+ label='language')
+ accent_dropdown_3 = gr.Dropdown(choices=['no-accent', 'English', '中文', '日本語'], value='no-accent',
+ label='accent')
+ preset_dropdown_3 = gr.Dropdown(choices=preset_list, value=None, label='Voice preset')
+ prompt_file = gr.File(file_count='single', file_types=['.npz'], interactive=True)
+ with gr.Column():
+ text_output_3 = gr.Textbox(label="Message")
+ audio_output_3 = gr.Audio(label="Output Audio", elem_id="tts-audio")
+ btn_3 = gr.Button("Generate!")
+ btn_3.click(infer_from_prompt,
+ inputs=[textbox_3, language_dropdown_3, accent_dropdown_3, preset_dropdown_3, prompt_file],
+ outputs=[text_output_3, audio_output_3])
+ gr.Examples(examples=infer_from_prompt_examples,
+ inputs=[textbox_3, language_dropdown_3, accent_dropdown_3, preset_dropdown_3, prompt_file],
+ outputs=[text_output_3, audio_output_3],
+ fn=infer_from_prompt,
+ cache_examples=False,)
+ with gr.Tab("Infer long text"):
+ gr.Markdown(long_text_md)
+ with gr.Row():
+ with gr.Column():
+ textbox_4 = gr.TextArea(label="Text",
+ placeholder="Type your sentence here",
+ value=long_text_example, elem_id=f"tts-input")
+ language_dropdown_4 = gr.Dropdown(choices=['auto-detect', 'English', '中文', '日本語'], value='auto-detect',
+ label='language')
+ accent_dropdown_4 = gr.Dropdown(choices=['no-accent', 'English', '中文', '日本語'], value='no-accent',
+ label='accent')
+ preset_dropdown_4 = gr.Dropdown(choices=preset_list, value=None, label='Voice preset')
+ prompt_file_4 = gr.File(file_count='single', file_types=['.npz'], interactive=True)
+ with gr.Column():
+ text_output_4 = gr.TextArea(label="Message")
+ audio_output_4 = gr.Audio(label="Output Audio", elem_id="tts-audio")
+ btn_4 = gr.Button("Generate!")
+ btn_4.click(infer_long_text,
+ inputs=[textbox_4, preset_dropdown_4, prompt_file_4, language_dropdown_4, accent_dropdown_4],
+ outputs=[text_output_4, audio_output_4])
+
+app.launch()
\ No newline at end of file
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..68f9defe677e03da5224c42cb28932f2e7f75ada
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1 @@
+from .collation import *
diff --git a/data/collation.py b/data/collation.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd8665b2afc40f0e5b13a07b92cebe88aecdbf66
--- /dev/null
+++ b/data/collation.py
@@ -0,0 +1,118 @@
+from pathlib import Path
+from typing import List, Tuple
+
+import numpy as np
+import torch
+
+
+class TextTokenCollater:
+ """Collate list of text tokens
+
+ Map sentences to integers. Sentences are padded to equal length.
+ Beginning and end-of-sequence symbols can be added.
+
+ Example:
+ >>> token_collater = TextTokenCollater(text_tokens)
+ >>> tokens_batch, tokens_lens = token_collater(text)
+
+ Returns:
+ tokens_batch: IntTensor of shape (B, L)
+ B: batch dimension, number of input sentences
+ L: length of the longest sentence
+ tokens_lens: IntTensor of shape (B,)
+            Length of each sentence after adding <bos> and <eos>
+            but before padding.
+ """
+
+ def __init__(
+ self,
+ text_tokens: List[str],
+ add_eos: bool = True,
+ add_bos: bool = True,
+        pad_symbol: str = "<pad>",
+        bos_symbol: str = "<bos>",
+        eos_symbol: str = "<eos>",
+ ):
+ self.pad_symbol = pad_symbol
+
+ self.add_eos = add_eos
+ self.add_bos = add_bos
+
+ self.bos_symbol = bos_symbol
+ self.eos_symbol = eos_symbol
+
+ unique_tokens = (
+ [pad_symbol]
+ + ([bos_symbol] if add_bos else [])
+ + ([eos_symbol] if add_eos else [])
+ + sorted(text_tokens)
+ )
+
+ self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
+ self.idx2token = [token for token in unique_tokens]
+
+ def index(
+ self, tokens_list: List[str]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ seqs, seq_lens = [], []
+ for tokens in tokens_list:
+            assert all(s in self.token2idx for s in tokens)
+ seq = (
+ ([self.bos_symbol] if self.add_bos else [])
+ + list(tokens)
+ + ([self.eos_symbol] if self.add_eos else [])
+ )
+ seqs.append(seq)
+ seq_lens.append(len(seq))
+
+ max_len = max(seq_lens)
+ for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
+ seq.extend([self.pad_symbol] * (max_len - seq_len))
+
+ tokens = torch.from_numpy(
+ np.array(
+ [[self.token2idx[token] for token in seq] for seq in seqs],
+ dtype=np.int64,
+ )
+ )
+ tokens_lens = torch.IntTensor(seq_lens)
+
+ return tokens, tokens_lens
+
+ def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
+ tokens_seqs = [[p for p in text] for text in texts]
+ max_len = len(max(tokens_seqs, key=len))
+
+ seqs = [
+ ([self.bos_symbol] if self.add_bos else [])
+ + list(seq)
+ + ([self.eos_symbol] if self.add_eos else [])
+ + [self.pad_symbol] * (max_len - len(seq))
+ for seq in tokens_seqs
+ ]
+
+ tokens_batch = torch.from_numpy(
+ np.array(
+ [seq for seq in seqs],
+ dtype=np.int64,
+ )
+ )
+
+ tokens_lens = torch.IntTensor(
+ [
+ len(seq) + int(self.add_eos) + int(self.add_bos)
+ for seq in tokens_seqs
+ ]
+ )
+
+ return tokens_batch, tokens_lens
+
+
+def get_text_token_collater() -> TextTokenCollater:
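+    # A minimal collater: BOS/EOS are disabled, so it only batches and pads the phoneme token sequences.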
+ collater = TextTokenCollater(
+ ['0'], add_bos=False, add_eos=False
+ )
+ return collater
diff --git a/data/tokenizer.py b/data/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b8e889641b6715b9b4fa3cffd3dd7bef06ad7e9
--- /dev/null
+++ b/data/tokenizer.py
@@ -0,0 +1,117 @@
+#!/usr/bin/env python3
+# Copyright 2023 (authors: Feiteng Li)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from dataclasses import asdict, dataclass
+from typing import Any, Dict, List, Optional, Pattern, Union
+
+import numpy as np
+import torch
+import torchaudio
+from encodec import EncodecModel
+from encodec.utils import convert_audio
+
+def remove_encodec_weight_norm(model):
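+    # Fold weight normalization out of EnCodec's encoder/decoder convolutions; the produced codes
+    # are unchanged (see the self-test at the bottom of this file).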
+ from encodec.modules import SConv1d
+ from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
+ from torch.nn.utils import remove_weight_norm
+
+ encoder = model.encoder.model
+ for key in encoder._modules:
+ if isinstance(encoder._modules[key], SEANetResnetBlock):
+ remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
+ block_modules = encoder._modules[key].block._modules
+ for skey in block_modules:
+ if isinstance(block_modules[skey], SConv1d):
+ remove_weight_norm(block_modules[skey].conv.conv)
+ elif isinstance(encoder._modules[key], SConv1d):
+ remove_weight_norm(encoder._modules[key].conv.conv)
+
+ decoder = model.decoder.model
+ for key in decoder._modules:
+ if isinstance(decoder._modules[key], SEANetResnetBlock):
+ remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
+ block_modules = decoder._modules[key].block._modules
+ for skey in block_modules:
+ if isinstance(block_modules[skey], SConv1d):
+ remove_weight_norm(block_modules[skey].conv.conv)
+ elif isinstance(decoder._modules[key], SConvTranspose1d):
+ remove_weight_norm(decoder._modules[key].convtr.convtr)
+ elif isinstance(decoder._modules[key], SConv1d):
+ remove_weight_norm(decoder._modules[key].conv.conv)
+
+
+class AudioTokenizer:
+ """EnCodec audio."""
+
+ def __init__(
+ self,
+ device: Any = None,
+ ) -> None:
+ # Instantiate a pretrained EnCodec model
+ model = EncodecModel.encodec_model_24khz()
+ model.set_target_bandwidth(6.0)
+ remove_encodec_weight_norm(model)
+
+ if not device:
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda:0")
+
+ self._device = device
+
+ self.codec = model.to(device)
+ self.sample_rate = model.sample_rate
+ self.channels = model.channels
+
+ @property
+ def device(self):
+ return self._device
+
+ def encode(self, wav: torch.Tensor) -> torch.Tensor:
+ return self.codec.encode(wav.to(self.device))
+
+ def decode(self, frames: torch.Tensor) -> torch.Tensor:
+ return self.codec.decode(frames)
+
+
+def tokenize_audio(tokenizer: AudioTokenizer, audio):
+ # Load and pre-process the audio waveform
+ if isinstance(audio, str):
+ wav, sr = torchaudio.load(audio)
+ else:
+ wav, sr = audio
+ wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
+ wav = wav.unsqueeze(0)
+
+ # Extract discrete codes from EnCodec
+ with torch.no_grad():
+ encoded_frames = tokenizer.encode(wav)
+ return encoded_frames
+
+
+if __name__ == "__main__":
+ model = EncodecModel.encodec_model_24khz()
+ model.set_target_bandwidth(6.0)
+
+ samples = torch.from_numpy(np.random.random([4, 1, 1600])).type(
+ torch.float32
+ )
+ codes_raw = model.encode(samples)
+
+ remove_encodec_weight_norm(model)
+ codes_norm = model.encode(samples)
+
+ assert torch.allclose(codes_raw[0][0], codes_norm[0][0])
diff --git a/descriptions.py b/descriptions.py
new file mode 100644
index 0000000000000000000000000000000000000000..25ae5afca7a4a18acb543ecd008151821451d97d
--- /dev/null
+++ b/descriptions.py
@@ -0,0 +1,38 @@
+top_md = """
+# VALL-E X
+
+
+Duplicate this Space to skip the queue.
+VALL-E X can synthesize high-quality personalized speech with only a 3-second enrolled recording of
+an unseen speaker as an acoustic prompt, even in another language for a monolingual speaker.
+This implementation supports zero-shot, monolingual and cross-lingual text-to-speech in three languages (English, Chinese, Japanese).
+See this [demo](https://plachtaa.github.io/) page for more details.
+"""
+
+infer_from_audio_md = """
+Upload a 3~10 second speech clip as the audio prompt and type in the text you'd like to synthesize.
+The model will synthesize the given text in the same voice as your audio prompt.
+It also tends to preserve the emotion and acoustic environment of the prompt.
+For faster inference, use **"Make prompt"** to get a `.npz` file with the encoded audio prompt, then use it in **"Infer from prompt"**.
+"""
+
+make_prompt_md = """
+Upload a 3~10 second speech clip as the audio prompt
+to get a `.npz` file containing the encoded prompt. Use it in **"Infer from prompt"**.
+"""
+
+infer_from_prompt_md = """
+Faster than **"Infer from audio"**.
+Run **"Make prompt"** first, then upload the encoded prompt (a `.npz` file) or choose a voice preset.
+"""
+
+long_text_md = """
+Long text is split into sentences, and each sentence is synthesized separately.
+Please make a prompt or use a preset prompt when inferring long text.
+"""
+
+long_text_example = "Just a few years ago, there were no legions of deep learning scientists developing intelligent products and services at major companies and startups. When we entered the field, machine learning did not command headlines in daily newspapers. Our parents had no idea what machine learning was, let alone why we might prefer it to a career in medicine or law. Machine learning was a blue skies academic discipline whose industrial significance was limited to a narrow set of real-world applications, including speech recognition and computer vision. Moreover, many of these applications required so much domain knowledge that they were often regarded as entirely separate areas for which machine learning was one small component. At that time, neural networks—the predecessors of the deep learning methods that we focus on in this book—were generally regarded as outmoded."
\ No newline at end of file
diff --git a/epoch-10.pt b/epoch-10.pt
new file mode 100644
index 0000000000000000000000000000000000000000..9cb6f4c9fd03560d111452eaba307e302306ef7d
--- /dev/null
+++ b/epoch-10.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c5fcd05ee0c9c84a16a7b44495c46262177e66d5d454c20ca5f1da9832dbd5ac
+size 1482302113
diff --git a/examples.py b/examples.py
new file mode 100644
index 0000000000000000000000000000000000000000..205210e0d03f1203648c8fc327da713f9db5eb4e
--- /dev/null
+++ b/examples.py
@@ -0,0 +1,24 @@
+infer_from_audio_examples = [
+ ["This is how this machine has taken my voice.", 'English', 'no-accent', "prompts/en-2.wav", None, "Wow, look at that! That's no ordinary Teddy bear!"],
+ ["我喜欢抽电子烟,尤其是锐刻五代。", '中文', 'no-accent', "prompts/zh-1.wav", None, "今天我很荣幸,"],
+ ["私の声を真似するのはそんなに面白いですか?", '日本語', 'no-accent', "prompts/ja-2.ogg", None, "初めまして、朝武よしのです。"],
+ ["你可以听得出来我有多困。", '中文', 'no-accent', "prompts/en-1.wav", None, ""],
+ ["この文は、クロスリンガル合成の例です。", '日本語', 'no-accent', "prompts/zh-2.wav", None, ""],
+ ["Actually, I can't speak English, but this machine helped me do it.", 'English', 'no-accent', "prompts/ja-1.wav", None, ""],
+]
+
+make_npz_prompt_examples = [
+ ["Gem-trader", "prompts/en-2.wav", None, "Wow, look at that! That's no ordinary Teddy bear!"],
+ ["Ding Zhen", "prompts/zh-1.wav", None, "今天我很荣幸,"],
+ ["Yoshino", "prompts/ja-2.ogg", None, "初めまして、朝武よしのです。"],
+ ["Sleepy-woman", "prompts/en-1.wav", None, ""],
+ ["Yae", "prompts/zh-2.wav", None, ""],
+ ["Cafe", "prompts/ja-1.wav", None, ""],
+]
+
+infer_from_prompt_examples = [
+ ["A prompt contains voice, prosody and emotion information of a certain speaker.", "English", "no-accent", "vctk_1", None],
+ ["This prompt is made with an audio of three seconds.", "English", "no-accent", "librispeech_1", None],
+ ["This prompt is made with Chinese speech", "English", "no-accent", "seel", None],
+]
+
diff --git a/images/vallex_framework.jpg b/images/vallex_framework.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c8dea042c2e3919003721a6adeb67cb658ed76bc
Binary files /dev/null and b/images/vallex_framework.jpg differ
diff --git a/macros.py b/macros.py
new file mode 100644
index 0000000000000000000000000000000000000000..b192fccde1a11da26cff026c9a08c8ff54915907
--- /dev/null
+++ b/macros.py
@@ -0,0 +1,39 @@
+NUM_LAYERS = 12
+NUM_HEAD = 16
+N_DIM = 1024
+PREFIX_MODE = 1
+NUM_QUANTIZERS = 8
+SAMPLE_RATE = 24000
+
+lang2token = {
+ 'zh': "[ZH]",
+ 'ja': "[JA]",
+ "en": "[EN]",
+ 'mix': "",
+}
+
+lang2code = {
+ 'zh': 0,
+ 'ja': 1,
+ "en": 2,
+}
+
+token2lang = {
+ '[ZH]': "zh",
+ '[JA]': "ja",
+ "[EN]": "en",
+ "": "mix"
+}
+
+code2lang = {
+ 0: 'zh',
+ 1: 'ja',
+ 2: "en",
+}
+
+langdropdown2token = {
+ 'English': "[EN]",
+ '中文': "[ZH]",
+ '日本語': "[JA]",
+ 'Mix': "",
+}
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3964a73a02c98de656da931b2c3f6121dbad7a28
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,126 @@
+import argparse
+
+import torch.nn as nn
+# from icefall.utils import AttributeDict, str2bool
+
+from .macros import (
+ NUM_AUDIO_TOKENS,
+ NUM_MEL_BINS,
+ NUM_SPEAKER_CLASSES,
+ NUM_TEXT_TOKENS,
+ SPEAKER_EMBEDDING_DIM,
+)
+from .vallex import VALLE, VALLF
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="VALL-E",
+ help="VALL-E, VALL-F, Transformer.",
+ )
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=1024,
+ help="Embedding dimension in the decoder model.",
+ )
+ parser.add_argument(
+ "--nhead",
+ type=int,
+ default=16,
+ help="Number of attention heads in the Decoder layers.",
+ )
+ parser.add_argument(
+ "--num-decoder-layers",
+ type=int,
+ default=12,
+ help="Number of Decoder layers.",
+ )
+ parser.add_argument(
+ "--scale-factor",
+ type=float,
+ default=1.0,
+ help="Model scale factor which will be assigned different meanings in different models.",
+ )
+ parser.add_argument(
+ "--norm-first",
+ type=bool,
+ default=True,
+ help="Pre or Post Normalization.",
+ )
+ parser.add_argument(
+ "--add-prenet",
+ type=bool,
+ default=False,
+ help="Whether add PreNet after Inputs.",
+ )
+
+ # VALL-E & F
+ parser.add_argument(
+ "--prefix-mode",
+ type=int,
+ default=1,
+ help="The mode for how to prefix VALL-E NAR Decoder, "
+ "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
+ )
+ parser.add_argument(
+ "--share-embedding",
+ type=bool,
+ default=True,
+ help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
+ )
+ parser.add_argument(
+ "--prepend-bos",
+ type=bool,
+ default=False,
+ help="Whether prepend to the acoustic tokens -> AR Decoder inputs.",
+ )
+ parser.add_argument(
+ "--num-quantizers",
+ type=int,
+ default=8,
+ help="Number of Audio/Semantic quantization layers.",
+ )
+
+ # Transformer
+ parser.add_argument(
+ "--scaling-xformers",
+ type=bool,
+ default=False,
+ help="Apply Reworked Conformer scaling on Transformers.",
+ )
+
+
+def get_model(params) -> nn.Module:
+ if params.model_name.lower() in ["vall-f", "vallf"]:
+ model = VALLF(
+ params.decoder_dim,
+ params.nhead,
+ params.num_decoder_layers,
+ norm_first=params.norm_first,
+ add_prenet=params.add_prenet,
+ prefix_mode=params.prefix_mode,
+ share_embedding=params.share_embedding,
+ nar_scale_factor=params.scale_factor,
+ prepend_bos=params.prepend_bos,
+ num_quantizers=params.num_quantizers,
+ )
+ elif params.model_name.lower() in ["vall-e", "valle"]:
+ model = VALLE(
+ params.decoder_dim,
+ params.nhead,
+ params.num_decoder_layers,
+ norm_first=params.norm_first,
+ add_prenet=params.add_prenet,
+ prefix_mode=params.prefix_mode,
+ share_embedding=params.share_embedding,
+ nar_scale_factor=params.scale_factor,
+ prepend_bos=params.prepend_bos,
+ num_quantizers=params.num_quantizers,
+ )
+ else:
+ raise ValueError("No such model")
+
+ return model
diff --git a/models/macros.py b/models/macros.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbc54966f43b2ef27d87c3b4bc69cb866d2b8fd0
--- /dev/null
+++ b/models/macros.py
@@ -0,0 +1,11 @@
+# Text
+NUM_TEXT_TOKENS = 2048
+
+# Audio
+NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins
+NUM_MEL_BINS = 100 # BigVGAN bigvgan_24khz_100band
+
+
+# Speaker
+NUM_SPEAKER_CLASSES = 4096
+SPEAKER_EMBEDDING_DIM = 64
diff --git a/models/vallex.py b/models/vallex.py
new file mode 100644
index 0000000000000000000000000000000000000000..745b72bcff05ecbb082b41691f978ee90e382123
--- /dev/null
+++ b/models/vallex.py
@@ -0,0 +1,851 @@
+# Copyright 2023 (authors: Feiteng Li)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from typing import Dict, Iterator, List, Tuple, Union
+import gc
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# from icefall.utils import make_pad_mask
+# from torchmetrics.classification import MulticlassAccuracy
+
+from modules.embedding import SinePositionalEmbedding, TokenEmbedding
+from modules.transformer import (
+ AdaptiveLayerNorm,
+ LayerNorm,
+ TransformerDecoderLayer,
+ TransformerEncoder,
+ TransformerEncoderLayer,
+)
+
+from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
+
+import psutil
+def get_memory_usage():
+ process = psutil.Process()
+ memory_info = process.memory_info()
+
+ memory_used = memory_info.rss
+ memory_used_mb = memory_used / (1024 * 1024)
+
+ return memory_used_mb
+
+class Transpose(nn.Identity):
+ """(N, T, D) -> (N, D, T)"""
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return input.transpose(1, 2)
+
+
+# NOTE: There are two ways to implement the model
+# 1) [VALL-F] standard TransformerDecoder, use x as memory
+# 2) [VALL-E] modified TransformerDecoder like GPT-x(e.g. causal TransformerEncoder),
+# use x as the prefix of decoder inputs
+class VALLF(nn.Module):
+ """It implements https://arxiv.org/abs/2301.02111
+ "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ nhead: int,
+ num_layers: int,
+ norm_first: bool = True,
+ add_prenet: bool = False,
+ decoder_cls: Union[
+ nn.TransformerDecoder, nn.TransformerEncoder
+ ] = nn.TransformerDecoder,
+ decoder_layer_cls: Union[
+ TransformerDecoderLayer, TransformerEncoderLayer
+ ] = TransformerDecoderLayer,
+ prefix_mode: int = 0,
+ share_embedding: bool = True,
+ nar_scale_factor: float = 1.0,
+ prepend_bos: bool = True,
+ num_quantizers: int = 8,
+ ):
+ """
+ Args:
+ d_model:
+ The number of expected features in the input (required).
+ nhead:
+ The number of heads in the multiheadattention models (required).
+ num_layers:
+ The number of sub-decoder-layers in the decoder (required).
+ """
+ super().__init__()
+ nar_d_model = int(d_model * nar_scale_factor)
+
+ self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
+ self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)
+
+ # ID NUM_AUDIO_TOKENS -> PAD
+ # ID NUM_AUDIO_TOKENS + 1 -> BOS
+ self.ar_audio_prepend_bos = prepend_bos
+ self.ar_audio_embedding = TokenEmbedding(
+ d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
+ )
+
+ # PreNet
+ if add_prenet:
+ self.ar_text_prenet = nn.Sequential(
+ Transpose(),
+ nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
+ nn.BatchNorm1d(d_model),
+ nn.ReLU(),
+ nn.Dropout(0.5),
+ nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
+ nn.BatchNorm1d(d_model),
+ nn.ReLU(),
+ nn.Dropout(0.5),
+ nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
+ nn.BatchNorm1d(d_model),
+ nn.ReLU(),
+ nn.Dropout(0.5),
+ Transpose(),
+ nn.Linear(d_model, d_model),
+ )
+
+ self.ar_audio_prenet = nn.Sequential(
+ nn.Linear(d_model, 256),
+ nn.ReLU(),
+ nn.Dropout(0.25),
+ nn.Linear(256, 256),
+ nn.ReLU(),
+ nn.Dropout(0.25),
+ nn.Linear(256, d_model),
+ )
+ else:
+ self.ar_text_prenet = nn.Identity()
+ self.ar_audio_prenet = nn.Identity()
+
+ self.ar_text_position = SinePositionalEmbedding(
+ d_model,
+ dropout=0.1,
+ scale=False,
+ alpha=True,
+ )
+ self.ar_audio_position = SinePositionalEmbedding(
+ d_model,
+ dropout=0.1,
+ scale=False,
+ alpha=True,
+ )
+
+ self.ar_decoder = decoder_cls(
+ decoder_layer_cls(
+ d_model,
+ nhead,
+ dim_feedforward=d_model * 4,
+ dropout=0.1,
+ batch_first=True,
+ norm_first=norm_first,
+ ),
+ num_layers=num_layers,
+ norm=LayerNorm(d_model) if norm_first else None,
+ )
+ self.ar_predict_layer = nn.Linear(
+ d_model, NUM_AUDIO_TOKENS + 1, bias=False
+ )
+
+ self.rng = random.Random(0)
+ self.num_heads = nhead
+ self.prefix_mode = prefix_mode
+ self.num_quantizers = num_quantizers
+
+ assert num_quantizers >= 1
+ if num_quantizers > 1:
+ self.nar_audio_embeddings = nn.ModuleList(
+ [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)]
+ + [
+ TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
+ for i in range(num_quantizers - 1)
+ ]
+ ) # W_a
+
+ # PreNet
+ if add_prenet:
+ self.nar_text_prenet = nn.Sequential(
+ Transpose(),
+ nn.Conv1d(
+ nar_d_model, nar_d_model, kernel_size=5, padding="same"
+ ),
+ nn.BatchNorm1d(nar_d_model),
+ nn.ReLU(),
+ nn.Dropout(0.5),
+ nn.Conv1d(
+ nar_d_model, nar_d_model, kernel_size=5, padding="same"
+ ),
+ nn.BatchNorm1d(nar_d_model),
+ nn.ReLU(),
+ nn.Dropout(0.5),
+ nn.Conv1d(
+ nar_d_model, nar_d_model, kernel_size=5, padding="same"
+ ),
+ nn.BatchNorm1d(nar_d_model),
+ nn.ReLU(),
+ nn.Dropout(0.5),
+ Transpose(),
+ nn.Linear(nar_d_model, nar_d_model),
+ )
+ self.nar_audio_prenet = nn.Sequential(
+ nn.Linear(nar_d_model, 256),
+ nn.ReLU(),
+ nn.Dropout(0.25),
+ nn.Linear(256, 256),
+ nn.ReLU(),
+ nn.Dropout(0.25),
+ nn.Linear(256, nar_d_model),
+ )
+ else:
+ self.nar_text_prenet = nn.Identity()
+ self.nar_audio_prenet = nn.Identity()
+
+ self.nar_text_position = SinePositionalEmbedding(
+ nar_d_model,
+ dropout=0.0,
+ scale=False,
+ alpha=False,
+ )
+ self.nar_audio_position = SinePositionalEmbedding(
+ nar_d_model,
+ dropout=0.1,
+ scale=False,
+ alpha=False,
+ )
+
+ self.nar_decoder = decoder_cls(
+ decoder_layer_cls(
+ nar_d_model,
+ int(nhead * nar_scale_factor),
+ dim_feedforward=nar_d_model * 4,
+ dropout=0.1,
+ batch_first=True,
+ norm_first=norm_first,
+ adaptive_layer_norm=True,
+ ),
+ num_layers=int(num_layers * nar_scale_factor),
+ norm=AdaptiveLayerNorm(
+ nar_d_model, norm=nn.LayerNorm(nar_d_model)
+ )
+ if norm_first
+ else None,
+ )
+ self.nar_predict_layers = nn.ModuleList(
+ [
+ nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
+ for i in range(num_quantizers - 1)
+ ]
+ )
+ self.nar_stage_embeddings = nn.ModuleList(
+ [
+ TokenEmbedding(nar_d_model, 1)
+ for i in range(num_quantizers - 1)
+ ]
+ )
+
+ if share_embedding:
+ # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa
+ # NOTE(Feiteng): In the experiment, this undermines accuracy
+ # self.ar_predict_layer.weight = self.ar_audio_embedding.weight
+
+ # We also share the parameters of the acoustic embedding layer and the output prediction layer,
+ # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
+ for j in range(0, num_quantizers - 2):
+ self.nar_predict_layers[
+ j
+ ].weight = self.nar_audio_embeddings[j + 2].weight
+
+ def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
+ assert stage > 0
+ if stage == 1:
+ for name, param in self.named_parameters():
+ if name.startswith("ar_"):
+ print(f" AR parameter: {name}")
+ yield param
+
+ if stage == 2:
+ for name, param in self.named_parameters():
+ if name.startswith("nar_"):
+ print(f"NAR parameter: {name}")
+ yield param
+
+ def stage_named_parameters(
+ self, stage: int = 1
+ ) -> Iterator[Tuple[str, nn.Parameter]]:
+ assert stage > 0
+ if stage == 1:
+ for pair in self.named_parameters():
+ if pair[0].startswith("ar_"):
+ yield pair
+
+ if stage == 2:
+ for pair in self.named_parameters():
+ if pair[0].startswith("nar_"):
+ yield pair
+
+ def pad_y_eos(self, y, y_mask_int, eos_id):
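+        # Append EOS to the targets (padded positions become eos_id) and build shifted (input, target) pairs; optionally prepend BOS to the inputs.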
+ targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
+ y_mask_int, (0, 1), value=1
+ )
+ # inputs, targets
+ if self.ar_audio_prepend_bos:
+ return (
+ F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1),
+ targets,
+ )
+
+ return targets[:, :-1], targets[:, 1:]
+
+ def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes, prefix_mode):
+ # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
+ # from the same utterance.
+ # We implement this differently.
+ if prefix_mode == 0:
+ # no prefix
+ prefix_len = 0
+ y_emb = self.nar_audio_embeddings[0](y)
+ for j in range(1, nar_stage):
+ # Formula (4) (5)
+ y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
+ elif prefix_mode == 1:
+            # prefix at beginning
+ int_low = (0.25 * y_lens.min()).type(torch.int64).item()
+ prefix_len = torch.randint(0, int_low * 2, size=()).item()
+ prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames
+
+ y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
+ y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
+ for j in range(1, self.num_quantizers):
+ y_prompts += self.nar_audio_embeddings[j](
+ codes[:, :prefix_len, j]
+ )
+ if j < nar_stage:
+ y_emb += self.nar_audio_embeddings[j](
+ codes[:, prefix_len:, j]
+ )
+ y_emb = torch.concat([y_prompts, y_emb], axis=1)
+ elif prefix_mode in [2, 4]:
+ if prefix_mode == 2:
+ # random prefix
+ prefix_len = min(225, int(0.25 * y_lens.min().item()))
+
+ y_prompts_codes = []
+ for b in range(codes.shape[0]):
+ start = self.rng.randint(0, y_lens[b].item() - prefix_len)
+ y_prompts_codes.append(
+ torch.clone(codes[b, start : start + prefix_len])
+ )
+ codes[
+ b, start : start + prefix_len, nar_stage
+ ] = NUM_AUDIO_TOKENS
+ y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
+ else:
+ prefix_len = y_prompts_codes.shape[1]
+
+ y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
+ y_emb = self.nar_audio_embeddings[0](y)
+ for j in range(1, self.num_quantizers):
+ y_prompts += self.nar_audio_embeddings[j](
+ y_prompts_codes[..., j]
+ )
+ if j < nar_stage:
+ y_emb += self.nar_audio_embeddings[j](codes[..., j])
+ y_emb = torch.concat([y_prompts, y_emb], axis=1)
+ else:
+ raise ValueError
+
+ return y_emb, prefix_len
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+        y: torch.Tensor,
+        y_lens: torch.Tensor,
+ reduction: str = "sum",
+ train_stage: int = 0,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
+ raise NotImplementedError
+
+ def inference(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ y: torch.Tensor,
+ enroll_x_lens: Union[torch.Tensor, None] = None,
+ top_k: int = -100,
+ temperature: float = 1.0,
+ ) -> torch.Tensor:
+ raise NotImplementedError
+
+ def visualize(
+ self,
+ predicts: Tuple[torch.Tensor],
+ batch: Dict[str, Union[List, torch.Tensor]],
+ output_dir: str,
+ limit: int = 4,
+ ) -> None:
+ raise NotImplementedError
+
+
+class VALLE(VALLF):
+ """It implements https://arxiv.org/abs/2301.02111
+ "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ nhead: int,
+ num_layers: int,
+ norm_first: bool = True,
+ add_prenet: bool = False,
+ prefix_mode: int = 0,
+ share_embedding: bool = True,
+ nar_scale_factor: float = 1.0,
+ **kwargs,
+ ):
+ """
+ Args:
+ d_model:
+ The number of expected features in the input (required).
+ nhead:
+ The number of heads in the multiheadattention models (required).
+ num_layers:
+ The number of sub-decoder-layers in the decoder (required).
+ """
+ super(VALLE, self).__init__(
+ d_model,
+ nhead,
+ num_layers,
+ norm_first=norm_first,
+ add_prenet=add_prenet,
+ decoder_cls=TransformerEncoder,
+ decoder_layer_cls=TransformerEncoderLayer,
+ prefix_mode=prefix_mode,
+ share_embedding=share_embedding,
+ nar_scale_factor=nar_scale_factor,
+ **kwargs,
+ )
+ self.language_ID = {
+ 'en': 0,
+ 'zh': 1,
+ 'ja': 2,
+ }
+ self.ar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
+ self.nar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+        y: torch.Tensor,
+        y_lens: torch.Tensor,
+ reduction: str = "sum",
+ train_stage: int = 0,
+ **kwargs,
+ ):
+ raise NotImplementedError
+
+ def inference(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ y: torch.Tensor,
+ enroll_x_lens: torch.Tensor,
+ top_k: int = -100,
+ temperature: float = 1.0,
+ prompt_language: str = None,
+ text_language: str = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x:
+ A 2-D tensor of shape (1, S).
+ x_lens:
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
+ before padding.
+ y:
+ A 3-D tensor of shape (1, T, 8).
+ top_k: (`optional`) int
+            The number of highest-probability tokens to keep for top-k filtering. Defaults to -100.
+          temperature: (`optional`) float
+            The value used to modulate the next-token probabilities. Must be strictly positive. Defaults to 1.0.
+ Returns:
+ Return the predicted audio code matrix.
+ """
+ assert x.ndim == 2, x.shape
+ assert x_lens.ndim == 1, x_lens.shape
+ assert y.ndim == 3, y.shape
+ assert y.shape[0] == 1, y.shape
+
+ assert torch.all(x_lens > 0)
+
+ # NOTE: x has been padded in TextTokenCollater
+ text = x
+ x = self.ar_text_embedding(text)
+ # Add language embedding
+ prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
+ if isinstance(text_language, str):
+ text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
+ elif isinstance(text_language, List):
+ text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
+ x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
+ x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
+ x = self.ar_text_prenet(x)
+ x = self.ar_text_position(x)
+
+ text_len = x_lens.max()
+ prompts = y
+ prefix_len = y.shape[1]
+
+ # AR Decoder
+        # TODO: Manage decoder steps to avoid repetitive computation
+ y = prompts[..., 0]
+ if self.ar_audio_prepend_bos:
+ y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)
+
+ x_len = x_lens.max()
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
+
+ kv_cache = None
+ use_kv_caching = True
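+        # AR sampling loop: generate first-quantizer tokens one at a time until EOS (or a length cap) is reached.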
+ while True:
+ y_emb = self.ar_audio_embedding(y)
+ y_emb = self.ar_audio_prenet(y_emb)
+ y_pos = self.ar_audio_position(y_emb)
+ xy_pos = torch.concat([x, y_pos], dim=1)
+
+ y_len = y.shape[1]
+ x_attn_mask_pad = F.pad(
+ x_attn_mask,
+ (0, y_len),
+ value=True,
+ )
+ y_attn_mask = F.pad(
+ torch.triu(
+ torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
+ ),
+ (x_len, 0),
+ value=False,
+ )
+ xy_attn_mask = torch.concat(
+ [x_attn_mask_pad, y_attn_mask], dim=0
+ ).to(y.device)
+
+
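+            # With KV caching, only the newest position has to be fed through the decoder; earlier keys/values come from the cache.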
+ if use_kv_caching and kv_cache is not None:
+ xy_pos = xy_pos[:, [-1]]
+ else:
+ pass
+
+ xy_dec, kv_cache = self.ar_decoder.infer(
+ xy_pos,
+ mask=xy_attn_mask,
+ past_kv=kv_cache,
+ use_cache=use_kv_caching,
+ )
+ # xy_dec, _ = self.ar_decoder(
+ # (xy_pos, None),
+ # mask=xy_attn_mask,
+ # )
+
+ logits = self.ar_predict_layer(xy_dec[:, -1])
+ samples = topk_sampling(
+ logits, top_k=top_k, top_p=1, temperature=temperature
+ )
+
+ if (
+ torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS
+ or samples[0, 0] == NUM_AUDIO_TOKENS
+ or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
+ ):
+                if prompts.shape[1] == y.shape[1]:
+                    raise RuntimeError(
+                        "A well-trained model shouldn't reach here: no new tokens were generated."
+                    )
+
+ print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
+
+ memory_used = get_memory_usage()
+ print(f"Current memory used: {memory_used:.2f} MB")
+ break
+
+ # safety measure, break if token sequence too long
+ if y.shape[1] > 2250:
+ print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
+ break
+
+ y = torch.concat([y, samples], dim=1)
+
+ codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
+ if self.num_quantizers == 1:
+ return torch.stack(codes, dim=-1)
+
+ # Non-AR Decoders
+ y_emb = self.nar_audio_embeddings[0](
+ y[:, int(self.ar_audio_prepend_bos) :]
+ )
+
+ if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes
+ enrolled_len = enroll_x_lens.max().item()
+ # SOS + Synthesis Text + EOS
+ text = torch.concat(
+ [
+ text[:, :1],
+ text[:, enrolled_len - 1 :],
+ ],
+ dim=1,
+ )
+ text_len = text_len - (enrolled_len - 2)
+ assert text.shape[0] == 1
+
+ x = self.nar_text_embedding(text)
+ # Add language embedding
+ prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
+ if isinstance(text_language, str):
+ text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
+ elif isinstance(text_language, List):
+ text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
+ x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
+ x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
+ x = self.nar_text_prenet(x)
+ x = self.nar_text_position(x)
+
+ if self.prefix_mode == 0:
+ for i, (predict_layer, embedding_layer) in enumerate(
+ zip(
+ self.nar_predict_layers,
+ self.nar_audio_embeddings[1:],
+ )
+ ):
+ y_pos = self.nar_audio_prenet(y_emb)
+ y_pos = self.nar_audio_position(y_pos)
+ xy_pos = torch.concat([x, y_pos], dim=1)
+
+ xy_dec, _ = self.nar_decoder(
+ (xy_pos, self.nar_stage_embeddings[i].weight)
+ )
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
+
+ samples = torch.argmax(logits, dim=-1)
+ codes.append(samples)
+
+ if i < self.num_quantizers - 2:
+ y_emb[:, :prefix_len] += embedding_layer(
+ prompts[..., i + 1]
+ )
+ y_emb[:, prefix_len:] += embedding_layer(samples)
+ else:
+ for j in range(1, self.num_quantizers):
+ y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
+ prompts[..., j]
+ )
+
+ for i, (predict_layer, embedding_layer) in enumerate(
+ zip(
+ self.nar_predict_layers,
+ self.nar_audio_embeddings[1:],
+ )
+ ):
+ y_pos = self.nar_audio_prenet(y_emb)
+ y_pos = self.nar_audio_position(y_pos)
+ xy_pos = torch.concat([x, y_pos], dim=1)
+
+ xy_dec, _ = self.nar_decoder(
+ (xy_pos, self.nar_stage_embeddings[i].weight)
+ )
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
+
+ samples = torch.argmax(logits, dim=-1)
+ codes.append(samples)
+
+ if i < self.num_quantizers - 2:
+ y_emb[:, prefix_len:] += embedding_layer(samples)
+
+ assert len(codes) == self.num_quantizers
+ del text_language_id, prompt_language_id, y_emb, x, y_pos, xy_pos, xy_dec, logits, samples, kv_cache, x_attn_mask, y_attn_mask, xy_attn_mask
+ gc.collect()
+ return torch.stack(codes, dim=-1)
+
+ def continual(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ y: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x:
+ A 2-D tensor of shape (1, S).
+ x_lens:
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
+ before padding.
+ y:
+ A 3-D tensor of shape (1, T, 8).
+ Returns:
+ Return the predicted audio code matrix.
+ """
+ assert x.ndim == 2, x.shape
+ assert x_lens.ndim == 1, x_lens.shape
+ assert y.ndim == 3, y.shape
+ assert y.shape[0] == 1, y.shape
+
+ assert torch.all(x_lens > 0)
+ assert self.num_quantizers == 8
+
+ # NOTE: x has been padded in TextTokenCollater
+ text = x
+ x = self.ar_text_embedding(text)
+ x = self.ar_text_prenet(x)
+ x = self.ar_text_position(x)
+
+ text_len = x_lens.max()
+
+ prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
+
+ # AR Decoder
+ prompts = y[:, :prefix_len]
+
+ codes = [y[:, prefix_len:, 0]]
+ # Non-AR Decoders
+ x = self.nar_text_embedding(text)
+ x = self.nar_text_prenet(x)
+ x = self.nar_text_position(x)
+
+ y_emb = self.nar_audio_embeddings[0](y[..., 0])
+
+ if self.prefix_mode == 0:
+ for i, (predict_layer, embedding_layer) in enumerate(
+ zip(
+ self.nar_predict_layers,
+ self.nar_audio_embeddings[1:],
+ )
+ ):
+ y_pos = self.nar_audio_position(y_emb)
+ y_pos = self.nar_audio_prenet(y_pos)
+ xy_pos = torch.concat([x, y_pos], dim=1)
+
+ xy_dec, _ = self.nar_decoder(
+ (xy_pos, self.nar_stage_embeddings[i].weight)
+ )
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
+
+ samples = torch.argmax(logits, dim=-1)
+ codes.append(samples)
+
+ if i < 6:
+ y_emb[:, :prefix_len] += embedding_layer(
+ prompts[..., i + 1]
+ )
+ y_emb[:, prefix_len:] += embedding_layer(samples)
+ else:
+ for j in range(1, 8):
+ y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
+ prompts[..., j]
+ )
+
+ for i, (predict_layer, embedding_layer) in enumerate(
+ zip(
+ self.nar_predict_layers,
+ self.nar_audio_embeddings[1:],
+ )
+ ):
+ y_pos = self.nar_audio_prenet(y_emb)
+ y_pos = self.nar_audio_position(y_pos)
+ xy_pos = torch.concat([x, y_pos], dim=1)
+
+ xy_dec, _ = self.nar_decoder(
+ (xy_pos, self.nar_stage_embeddings[i].weight)
+ )
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
+
+ samples = torch.argmax(logits, dim=-1)
+ codes.append(samples)
+
+ if i < 6:
+ y_emb[:, prefix_len:] += embedding_layer(samples)
+
+ assert len(codes) == 8
+ return torch.stack(codes, dim=-1)
+
+
+# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
+def top_k_top_p_filtering(
+ logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
+):
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+ Args:
+ logits: logits distribution shape (batch size, vocabulary size)
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+ """
+ if top_k > 0:
+ top_k = min(
+ max(top_k, min_tokens_to_keep), logits.size(-1)
+ ) # Safety check
+ # Remove all tokens with a probability less than the last token of the top-k
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits[indices_to_remove] = filter_value
+
+ if top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cumulative_probs = torch.cumsum(
+ F.softmax(sorted_logits, dim=-1), dim=-1
+ )
+
+        # Remove tokens with cumulative probability above the threshold (the first token above the threshold is kept via the shift below)
+ sorted_indices_to_remove = cumulative_probs > top_p
+ if min_tokens_to_keep > 1:
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+ # Shift the indices to the right to keep also the first token above the threshold
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+ ..., :-1
+ ].clone()
+ sorted_indices_to_remove[..., 0] = 0
+
+ # scatter sorted tensors to original indexing
+ indices_to_remove = sorted_indices_to_remove.scatter(
+ 1, sorted_indices, sorted_indices_to_remove
+ )
+ logits[indices_to_remove] = filter_value
+ return logits
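+
+# Illustrative example (not part of the original file): with top_k=2, only the two
+# largest logits survive and everything else is pushed to -inf before sampling.
+# >>> logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
+# >>> top_k_top_p_filtering(logits.clone(), top_k=2)
+# tensor([[-inf, -inf, 3., 4.]])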
+
+
+def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
+    # temperature: (`optional`) float
+    #     The value used to modulate the next-token probabilities. Must be strictly positive. Defaults to 1.0.
+    # top_k: (`optional`) int
+    #     The number of highest-probability vocabulary tokens to keep for top-k filtering. Defaults to 10.
+    # top_p: (`optional`) float
+    #     The cumulative-probability threshold for nucleus (top-p) sampling. Must be between 0 and 1. Defaults to 1.0.
+
+ # Temperature (higher temperature => more likely to sample low probability tokens)
+ if temperature != 1.0:
+ logits = logits / temperature
+ # Top-p/top-k filtering
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+ # Sample
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
+ return token
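+
+
+# Hedged usage sketch (not from the original repository): a quick sanity check of the
+# sampling helpers above. The vocabulary size of 8 and the hyper-parameters are arbitrary.
+if __name__ == "__main__":
+    _logits = torch.randn(1, 8)
+    _token = topk_sampling(_logits.clone(), top_k=3, top_p=1.0, temperature=0.9)
+    assert _token.shape == (1, 1)
+    # the sampled index must be one of the three highest-scoring entries
+    assert _token.item() in torch.topk(_logits, 3).indices[0].tolist()
+    print("topk_sampling sanity check passed:", _token.item())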
diff --git a/modules/__init__.py b/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/activation.py b/modules/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..e51c8eb49a378b039ed819820444f784fdc52dba
--- /dev/null
+++ b/modules/activation.py
@@ -0,0 +1,612 @@
+from typing import Optional, Tuple, List
+import math
+
+import torch
+from torch import Tensor
+from torch.nn import Linear, Module
+from torch.nn import functional as F
+from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
+from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
+from torch.nn.parameter import Parameter
+
+def _in_projection_packed(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ w: Tensor,
+ b: Optional[Tensor] = None,
+) -> List[Tensor]:
+ r"""
+ Performs the in-projection step of the attention operation, using packed weights.
+ Output is a triple containing projection tensors for query, key and value.
+
+ Args:
+ q, k, v: query, key and value tensors to be projected. For self-attention,
+ these are typically the same tensor; for encoder-decoder attention,
+ k and v are typically the same tensor. (We take advantage of these
+ identities for performance if they are present.) Regardless, q, k and v
+ must share a common embedding dimension; otherwise their shapes may vary.
+ w: projection weights for q, k and v, packed into a single tensor. Weights
+ are packed along dimension 0, in q, k, v order.
+ b: optional projection biases for q, k and v, packed into a single tensor
+ in q, k, v order.
+
+ Shape:
+ Inputs:
+ - q: :math:`(..., E)` where E is the embedding dimension
+ - k: :math:`(..., E)` where E is the embedding dimension
+ - v: :math:`(..., E)` where E is the embedding dimension
+ - w: :math:`(E * 3, E)` where E is the embedding dimension
+ - b: :math:`E * 3` where E is the embedding dimension
+
+ Output:
+ - in output list :math:`[q', k', v']`, each output tensor will have the
+ same shape as the corresponding input tensor.
+ """
+ E = q.size(-1)
+ if k is v:
+ if q is k:
+ # self-attention
+ return F.linear(q, w, b).chunk(3, dim=-1)
+ else:
+ # encoder-decoder attention
+ w_q, w_kv = w.split([E, E * 2])
+ if b is None:
+ b_q = b_kv = None
+ else:
+ b_q, b_kv = b.split([E, E * 2])
+ return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
+ else:
+ w_q, w_k, w_v = w.chunk(3)
+ if b is None:
+ b_q = b_k = b_v = None
+ else:
+ b_q, b_k, b_v = b.chunk(3)
+ return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
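+
+# Shape sketch (illustrative, not part of the original file): for self-attention with
+# embedding dim E=4, a packed (3*E, E) weight yields three projections of the input shape.
+# >>> x = torch.randn(2, 5, 4)            # (batch, seq, E)
+# >>> w = torch.randn(12, 4)              # (3*E, E)
+# >>> q, k, v = _in_projection_packed(x, x, x, w)
+# >>> q.shape == k.shape == v.shape == torch.Size([2, 5, 4])
+# True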
+
+def _scaled_dot_product_attention(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ attn_mask: Optional[Tensor] = None,
+ dropout_p: float = 0.0,
+) -> Tuple[Tensor, Tensor]:
+ r"""
+ Computes scaled dot product attention on query, key and value tensors, using
+ an optional attention mask if passed, and applying dropout if a probability
+ greater than 0.0 is specified.
+ Returns a tensor pair containing attended values and attention weights.
+
+ Args:
+ q, k, v: query, key and value tensors. See Shape section for shape details.
+ attn_mask: optional tensor containing mask values to be added to calculated
+ attention. May be 2D or 3D; see Shape section for details.
+ dropout_p: dropout probability. If greater than 0.0, dropout is applied.
+
+ Shape:
+ - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
+ and E is embedding dimension.
+ - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
+ and E is embedding dimension.
+ - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
+ and E is embedding dimension.
+ - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
+ shape :math:`(Nt, Ns)`.
+
+ - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
+ have shape :math:`(B, Nt, Ns)`
+ """
+ B, Nt, E = q.shape
+ q = q / math.sqrt(E)
+ # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
+ if attn_mask is not None:
+ attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
+ else:
+ attn = torch.bmm(q, k.transpose(-2, -1))
+
+ attn = F.softmax(attn, dim=-1)
+ if dropout_p > 0.0:
+ attn = F.dropout(attn, p=dropout_p)
+ # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
+ output = torch.bmm(attn, v)
+ return output, attn
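+
+# Shape sketch (illustrative, not part of the original file): B=2 sequences, Nt=3 queries
+# attending over Ns=5 keys/values with embedding dim E=8.
+# >>> q, k, v = torch.randn(2, 3, 8), torch.randn(2, 5, 8), torch.randn(2, 5, 8)
+# >>> out, attn = _scaled_dot_product_attention(q, k, v)
+# >>> out.shape, attn.shape
+# (torch.Size([2, 3, 8]), torch.Size([2, 3, 5]))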
+
+def multi_head_attention_forward(
+ x,
+ ipw,
+ ipb,
+ opw,
+ opb,
+ n_head,
+ attn_mask,
+ past_kv=None,
+ use_cache=False,
+):
+ # x = x.transpose(1, 0)
+ # tgt_len, bsz, embed_dim = x.shape
+ # head_dim = embed_dim // n_head
+ # q, k, v = _in_projection_packed(x, x, x, ipw, ipb)
+ # q = q.contiguous().view(tgt_len, bsz * n_head, head_dim).transpose(0, 1)
+ # k = k.contiguous().view(k.shape[0], bsz * n_head, head_dim).transpose(0, 1)
+ # v = v.contiguous().view(v.shape[0], bsz * n_head, head_dim).transpose(0, 1)
+
+ # new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
+ # new_attn_mask.masked_fill_(attn_mask, float("-inf"))
+ # attn_mask = new_attn_mask
+ #
+ # attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, 0.0)
+ # attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
+ # attn_output = torch._C._nn.linear(attn_output, opw, opb)
+ # attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
+
+ B, T, C = x.size()
+
+ q, k, v = torch._C._nn.linear(x, ipw, ipb).chunk(3, dim=-1)
+ k = k.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
+ q = q.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
+ v = v.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
+ if past_kv is not None:
+ past_key = past_kv[0]
+ past_value = past_kv[1]
+ k = torch.cat((past_key, k), dim=-2)
+ v = torch.cat((past_value, v), dim=-2)
+
+ FULL_T = k.shape[-2]
+
+ if use_cache is True:
+ present = (k, v)
+ else:
+ present = None
+
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+ att = att.masked_fill(attn_mask[FULL_T - T:FULL_T, :FULL_T], float('-inf'))
+ att = F.softmax(att, dim=-1)
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+ y = torch._C._nn.linear(y, opw, opb)
+ return (y, present)
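+
+# KV-cache sketch (illustrative, not part of the original file): the first call consumes the
+# whole prefix; later calls pass only the newest position plus `past_kv`, while the boolean
+# mask always covers the full (cached + new) length.
+# >>> C, H, T = 16, 4, 6
+# >>> ipw, ipb = torch.randn(3 * C, C), torch.zeros(3 * C)
+# >>> opw, opb = torch.randn(C, C), torch.zeros(C)
+# >>> mask = torch.triu(torch.ones(T + 1, T + 1, dtype=torch.bool), diagonal=1)
+# >>> y, kv = multi_head_attention_forward(
+# ...     torch.randn(1, T, C), ipw, ipb, opw, opb, H, mask[:T, :T], use_cache=True)
+# >>> y2, kv = multi_head_attention_forward(
+# ...     torch.randn(1, 1, C), ipw, ipb, opw, opb, H, mask, past_kv=kv, use_cache=True)
+# >>> y.shape, y2.shape, kv[0].shape       # cached keys: (B, n_head, T + 1, C // n_head)
+# (torch.Size([1, 6, 16]), torch.Size([1, 1, 16]), torch.Size([1, 4, 7, 4]))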
+
+
+class MultiheadAttention(Module):
+ r"""Allows the model to jointly attend to information
+ from different representation subspaces as described in the paper:
+    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
+
+ Multi-Head Attention is defined as:
+
+ .. math::
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
+
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
+
+ ``forward()`` will use a special optimized implementation if all of the following
+ conditions are met:
+
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
+ restriction will be loosened in the future.)
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
+ - training is disabled (using ``.eval()``)
+ - dropout is 0
+ - ``add_bias_kv`` is ``False``
+ - ``add_zero_attn`` is ``False``
+ - ``batch_first`` is ``True`` and the input is batched
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
+ - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
+    - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
+ nor ``attn_mask`` is passed
+
+    If the optimized implementation is in use, a
+    `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
+    ``query``/``key``/``value`` to represent padding more efficiently than using a
+    padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
+ will be returned, and an additional speedup proportional to the fraction of the input
+ that is padding can be expected.
+
+ Args:
+ embed_dim: Total dimension of the model.
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
+ Default: ``False``.
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
+ batch_first: If ``True``, then the input and output tensors are provided
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
+
+ Examples::
+
+ >>> # xdoctest: +SKIP
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
+
+ """
+ __constants__ = ["batch_first"]
+ bias_k: Optional[torch.Tensor]
+ bias_v: Optional[torch.Tensor]
+
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ dropout=0.0,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False,
+ kdim=None,
+ vdim=None,
+ batch_first=False,
+ linear1_cls=Linear,
+ linear2_cls=Linear,
+ device=None,
+ dtype=None,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super(MultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self._qkv_same_embed_dim = (
+ self.kdim == embed_dim and self.vdim == embed_dim
+ )
+
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.batch_first = batch_first
+ self.head_dim = embed_dim // num_heads
+ assert (
+ self.head_dim * num_heads == self.embed_dim
+ ), "embed_dim must be divisible by num_heads"
+
+ if add_bias_kv:
+ self.bias_k = Parameter(
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
+ )
+ self.bias_v = Parameter(
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
+ )
+ else:
+ self.bias_k = self.bias_v = None
+
+ if linear1_cls == Linear:
+ if not self._qkv_same_embed_dim:
+ self.q_proj_weight = Parameter(
+ torch.empty((embed_dim, embed_dim), **factory_kwargs)
+ )
+ self.k_proj_weight = Parameter(
+ torch.empty((embed_dim, self.kdim), **factory_kwargs)
+ )
+ self.v_proj_weight = Parameter(
+ torch.empty((embed_dim, self.vdim), **factory_kwargs)
+ )
+ self.register_parameter("in_proj_weight", None)
+ else:
+ self.in_proj_weight = Parameter(
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
+ )
+ self.register_parameter("q_proj_weight", None)
+ self.register_parameter("k_proj_weight", None)
+ self.register_parameter("v_proj_weight", None)
+
+ if bias:
+ self.in_proj_bias = Parameter(
+ torch.empty(3 * embed_dim, **factory_kwargs)
+ )
+ else:
+ self.register_parameter("in_proj_bias", None)
+ self.out_proj = NonDynamicallyQuantizableLinear(
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
+ )
+
+ self._reset_parameters()
+ else:
+ if not self._qkv_same_embed_dim:
+ raise NotImplementedError
+ else:
+ self.in_proj_linear = linear1_cls(
+ embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
+ )
+ self.in_proj_weight = self.in_proj_linear.weight
+
+ self.register_parameter("q_proj_weight", None)
+ self.register_parameter("k_proj_weight", None)
+ self.register_parameter("v_proj_weight", None)
+
+ if bias:
+ self.in_proj_bias = self.in_proj_linear.bias
+ else:
+ self.register_parameter("in_proj_bias", None)
+
+ self.out_proj = linear2_cls(
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
+ )
+
+ if self.bias_k is not None:
+ xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ xavier_normal_(self.bias_v)
+
+ self.add_zero_attn = add_zero_attn
+
+ def _reset_parameters(self):
+ if self._qkv_same_embed_dim:
+ xavier_uniform_(self.in_proj_weight)
+ else:
+ xavier_uniform_(self.q_proj_weight)
+ xavier_uniform_(self.k_proj_weight)
+ xavier_uniform_(self.v_proj_weight)
+
+ if self.in_proj_bias is not None:
+ constant_(self.in_proj_bias, 0.0)
+ constant_(self.out_proj.bias, 0.0)
+
+ if self.bias_k is not None:
+ xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ xavier_normal_(self.bias_v)
+
+ def __setstate__(self, state):
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
+ if "_qkv_same_embed_dim" not in state:
+ state["_qkv_same_embed_dim"] = True
+
+ super(MultiheadAttention, self).__setstate__(state)
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ average_attn_weights: bool = True,
+ ) -> Tuple[Tensor, Optional[Tensor]]:
+ r"""
+ Args:
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
+ Queries are compared against key-value pairs to produce the output.
+ See "Attention Is All You Need" for more details.
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
+ See "Attention Is All You Need" for more details.
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
+ See "Attention Is All You Need" for more details.
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
+ Binary and byte masks are supported.
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
+ Default: ``True``.
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
+ Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
+ corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
+ the attention weight.
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
+
+ Outputs:
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
+ embedding dimension ``embed_dim``.
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
+
+ .. note::
+ `batch_first` argument is ignored for unbatched inputs.
+ """
+ is_batched = query.dim() == 3
+ if key_padding_mask is not None:
+ _kpm_dtype = key_padding_mask.dtype
+ if _kpm_dtype != torch.bool and not torch.is_floating_point(
+ key_padding_mask
+ ):
+ raise AssertionError(
+ "only bool and floating types of key_padding_mask are supported"
+ )
+ why_not_fast_path = ""
+ if not is_batched:
+ why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
+ elif query is not key or key is not value:
+ # When lifting this restriction, don't forget to either
+ # enforce that the dtypes all match or test cases where
+ # they don't!
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
+ elif (
+ self.in_proj_bias is not None
+ and query.dtype != self.in_proj_bias.dtype
+ ):
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
+ elif (
+ self.in_proj_weight is not None
+ and query.dtype != self.in_proj_weight.dtype
+ ):
+ # this case will fail anyway, but at least they'll get a useful error message.
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
+ elif self.training:
+ why_not_fast_path = "training is enabled"
+ elif not self.batch_first:
+ why_not_fast_path = "batch_first was not True"
+ elif self.bias_k is not None:
+ why_not_fast_path = "self.bias_k was not None"
+ elif self.bias_v is not None:
+ why_not_fast_path = "self.bias_v was not None"
+ elif self.dropout:
+ why_not_fast_path = f"dropout was {self.dropout}, required zero"
+ elif self.add_zero_attn:
+ why_not_fast_path = "add_zero_attn was enabled"
+ elif not self._qkv_same_embed_dim:
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
+ elif attn_mask is not None:
+ why_not_fast_path = "attn_mask was not None"
+ elif query.is_nested and key_padding_mask is not None:
+ why_not_fast_path = (
+ "key_padding_mask is not supported with NestedTensor input"
+ )
+ elif self.num_heads % 2 == 1:
+ why_not_fast_path = "num_heads is odd"
+ elif torch.is_autocast_enabled():
+ why_not_fast_path = "autocast is enabled"
+
+ if not why_not_fast_path:
+ tensor_args = (
+ query,
+ key,
+ value,
+ self.in_proj_weight,
+ self.in_proj_bias,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ )
+ # We have to use list comprehensions below because TorchScript does not support
+ # generator expressions.
+ if torch.overrides.has_torch_function(tensor_args):
+ why_not_fast_path = "some Tensor argument has_torch_function"
+ elif not all(
+ [
+ (x is None or x.is_cuda or "cpu" in str(x.device))
+ for x in tensor_args
+ ]
+ ):
+ why_not_fast_path = (
+ "some Tensor argument is neither CUDA nor CPU"
+ )
+ elif torch.is_grad_enabled() and any(
+ [x is not None and x.requires_grad for x in tensor_args]
+ ):
+ why_not_fast_path = (
+ "grad is enabled and at least one of query or the "
+ "input/output projection weights or biases requires_grad"
+ )
+ if not why_not_fast_path:
+ return torch._native_multi_head_attention(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ key_padding_mask
+ if key_padding_mask is not None
+ else attn_mask,
+ need_weights,
+ average_attn_weights,
+ 1
+ if key_padding_mask is not None
+ else 0
+ if attn_mask is not None
+ else None,
+ )
+
+ any_nested = query.is_nested or key.is_nested or value.is_nested
+ assert not any_nested, (
+ "MultiheadAttention does not support NestedTensor outside of its fast path. "
+ + f"The fast path was not hit because {why_not_fast_path}"
+ )
+
+ if self.batch_first and is_batched:
+ # make sure that the transpose op does not affect the "is" property
+ if key is value:
+ if query is key:
+ query = key = value = query.transpose(1, 0)
+ else:
+ query, key = [x.transpose(1, 0) for x in (query, key)]
+ value = key
+ else:
+ query, key, value = [
+ x.transpose(1, 0) for x in (query, key, value)
+ ]
+
+ if not self._qkv_same_embed_dim:
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias,
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask,
+ need_weights=need_weights,
+ attn_mask=attn_mask,
+ use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj_weight,
+ k_proj_weight=self.k_proj_weight,
+ v_proj_weight=self.v_proj_weight,
+ average_attn_weights=average_attn_weights,
+ )
+ else:
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias,
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ training=self.training,
+ key_padding_mask=key_padding_mask,
+ need_weights=need_weights,
+ attn_mask=attn_mask,
+ average_attn_weights=average_attn_weights,
+ )
+ if self.batch_first and is_batched:
+ return attn_output.transpose(1, 0), attn_output_weights
+ else:
+ return attn_output, attn_output_weights
+
+ def infer(self,
+ x: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[Tensor] = None,
+ average_attn_weights: bool = True,
+ past_kv = None,
+ use_cache = False
+ ):
+ # x = x.transpose(1, 0)
+ y, kv = multi_head_attention_forward(
+ x=x,
+ ipw=self.in_proj_weight,
+ ipb=self.in_proj_bias,
+ opw=self.out_proj.weight,
+ opb=self.out_proj.bias,
+ n_head=self.num_heads,
+ attn_mask=attn_mask,
+ past_kv=past_kv,
+ use_cache=use_cache,
+ )
+ return (y, kv)
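+
+
+# Hedged smoke test (not part of the upstream module): builds a small batch-first
+# MultiheadAttention and checks that `infer` with KV caching returns the expected
+# shapes when decoding one extra step after a 6-token prefix.
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    mha = MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
+    causal_mask = torch.triu(torch.ones(7, 7, dtype=torch.bool), diagonal=1)
+    y, kv = mha.infer(torch.randn(1, 6, 16), attn_mask=causal_mask[:6, :6], use_cache=True)
+    y_step, kv = mha.infer(torch.randn(1, 1, 16), attn_mask=causal_mask, past_kv=kv, use_cache=True)
+    assert y.shape == (1, 6, 16) and y_step.shape == (1, 1, 16)
+    assert kv[0].shape == (1, 4, 7, 4)  # (batch, heads, cached length, head dim)
+    print("MultiheadAttention.infer KV-cache smoke test passed")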
diff --git a/modules/embedding.py b/modules/embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..17f6c316da3de6a432f4d43f9563800fdb6d58c4
--- /dev/null
+++ b/modules/embedding.py
@@ -0,0 +1,97 @@
+# Copyright 2023 (authors: Feiteng Li)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import torch
+import torch.nn as nn
+
+
+class TokenEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim_model: int,
+ vocab_size: int,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+
+ self.vocab_size = vocab_size
+ self.dim_model = dim_model
+
+ self.dropout = torch.nn.Dropout(p=dropout)
+ self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
+
+ @property
+ def weight(self) -> torch.Tensor:
+ return self.word_embeddings.weight
+
+ def embedding(self, index: int) -> torch.Tensor:
+ return self.word_embeddings.weight[index : index + 1]
+
+ def forward(self, x: torch.Tensor):
+ X = self.word_embeddings(x)
+ X = self.dropout(X)
+
+ return X
+
+
+class SinePositionalEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim_model: int,
+ dropout: float = 0.0,
+ scale: bool = False,
+ alpha: bool = False,
+ ):
+ super().__init__()
+ self.dim_model = dim_model
+ self.x_scale = math.sqrt(dim_model) if scale else 1.0
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
+ self.dropout = torch.nn.Dropout(p=dropout)
+
+ self.reverse = False
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, 4000))
+
+ def extend_pe(self, x):
+ """Reset the positional encodings."""
+ if self.pe is not None:
+ if self.pe.size(1) >= x.size(1):
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ pe = torch.zeros(x.size(1), self.dim_model)
+ if self.reverse:
+ position = torch.arange(
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
+ ).unsqueeze(1)
+ else:
+ position = torch.arange(
+ 0, x.size(1), dtype=torch.float32
+ ).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.dim_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.dim_model)
+ )
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ self.extend_pe(x)
+ output = x.unsqueeze(-1) if x.ndim == 2 else x
+ output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
+ return self.dropout(output)
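+
+
+# Hedged smoke test (not part of the upstream module): embeds a toy token sequence
+# and adds sinusoidal positions, checking only the resulting shape.
+if __name__ == "__main__":
+    emb = TokenEmbedding(dim_model=8, vocab_size=10)
+    pos = SinePositionalEmbedding(dim_model=8, scale=True)
+    tokens = torch.randint(0, 10, (2, 5))
+    out = pos(emb(tokens))
+    assert out.shape == (2, 5, 8)
+    print("embedding smoke test passed:", tuple(out.shape))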
diff --git a/modules/scaling.py b/modules/scaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..824a2077cedb787dd05bbad5ba6fe65099e11fcf
--- /dev/null
+++ b/modules/scaling.py
@@ -0,0 +1,1401 @@
+# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import collections
+import logging
+import random
+import math
+from functools import reduce
+from itertools import repeat
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Embedding as ScaledEmbedding
+
+from utils import Transpose
+
+
+class ActivationBalancerFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ scale_factor: Tensor,
+ sign_factor: Optional[Tensor],
+ channel_dim: int,
+ ) -> Tensor:
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ ctx.channel_dim = channel_dim
+ xgt0 = x > 0
+ if sign_factor is None:
+ ctx.save_for_backward(xgt0, scale_factor)
+ else:
+ ctx.save_for_backward(xgt0, scale_factor, sign_factor)
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+ if len(ctx.saved_tensors) == 3:
+ xgt0, scale_factor, sign_factor = ctx.saved_tensors
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
+ scale_factor = scale_factor.unsqueeze(-1)
+ sign_factor = sign_factor.unsqueeze(-1)
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+ else:
+ xgt0, scale_factor = ctx.saved_tensors
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
+ scale_factor = scale_factor.unsqueeze(-1)
+ factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+ neg_delta_grad = x_grad.abs() * factor
+ return (
+ x_grad - neg_delta_grad,
+ None,
+ None,
+ None,
+ )
+
+
+def _compute_scale_factor(
+ x: Tensor,
+ channel_dim: int,
+ min_abs: float,
+ max_abs: float,
+ gain_factor: float,
+ max_factor: float,
+) -> Tensor:
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+ x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
+
+ if min_abs == 0.0:
+ below_threshold = 0.0
+ else:
+        # below_threshold is 0 if x_abs_mean > min_abs; it can be at most
+        # max_factor when x_abs_mean is sufficiently far below min_abs.
+ below_threshold = (
+ (min_abs - x_abs_mean) * (gain_factor / min_abs)
+ ).clamp(min=0, max=max_factor)
+
+ above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
+ min=0, max=max_factor
+ )
+
+ return below_threshold - above_threshold
+
+
+def _compute_sign_factor(
+ x: Tensor,
+ channel_dim: int,
+ min_positive: float,
+ max_positive: float,
+ gain_factor: float,
+ max_factor: float,
+) -> Tensor:
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
+ proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
+ if min_positive == 0.0:
+ factor1 = 0.0
+ else:
+ # 0 if proportion_positive >= min_positive, else can be
+ # as large as max_factor.
+ factor1 = (
+ (min_positive - proportion_positive) * (gain_factor / min_positive)
+ ).clamp_(min=0, max=max_factor)
+
+ if max_positive == 1.0:
+ factor2 = 0.0
+ else:
+        # 0 if proportion_positive <= max_positive, else can be
+        # as large as max_factor (it enters sign_factor with a negative sign).
+ factor2 = (
+ (proportion_positive - max_positive)
+ * (gain_factor / (1.0 - max_positive))
+ ).clamp_(min=0, max=max_factor)
+ sign_factor = factor1 - factor2
+ # require min_positive != 0 or max_positive != 1:
+ assert not isinstance(sign_factor, float)
+ return sign_factor
+
+
+class ActivationScaleBalancerFunction(torch.autograd.Function):
+ """
+ This object is used in class ActivationBalancer when the user specified
+ min_positive=0, max_positive=1, so there are no constraints on the signs
+ of the activations and only the absolute value has a constraint.
+ """
+
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ sign_factor: Tensor,
+ scale_factor: Tensor,
+ channel_dim: int,
+ ) -> Tensor:
+ if channel_dim < 0:
+ channel_dim += x.ndim
+ ctx.channel_dim = channel_dim
+ xgt0 = x > 0
+ ctx.save_for_backward(xgt0, sign_factor, scale_factor)
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
+ xgt0, sign_factor, scale_factor = ctx.saved_tensors
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
+ sign_factor = sign_factor.unsqueeze(-1)
+ scale_factor = scale_factor.unsqueeze(-1)
+
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
+ neg_delta_grad = x_grad.abs() * factor
+ return (
+ x_grad - neg_delta_grad,
+ None,
+ None,
+ None,
+ )
+
+
+class RandomClampFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ min: Optional[float],
+ max: Optional[float],
+ prob: float,
+ reflect: float,
+ ) -> Tensor:
+ x_clamped = torch.clamp(x, min=min, max=max)
+ mask = torch.rand_like(x) < prob
+ ans = torch.where(mask, x_clamped, x)
+ if x.requires_grad:
+ ctx.save_for_backward(ans == x)
+ ctx.reflect = reflect
+ if reflect != 0.0:
+ ans = ans * (1.0 + reflect) - (x * reflect)
+ return ans
+
+ @staticmethod
+ def backward(
+ ctx, ans_grad: Tensor
+ ) -> Tuple[Tensor, None, None, None, None]:
+ (is_same,) = ctx.saved_tensors
+ x_grad = ans_grad * is_same.to(ans_grad.dtype)
+ reflect = ctx.reflect
+ if reflect != 0.0:
+ x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
+ return x_grad, None, None, None, None
+
+
+def random_clamp(
+ x: Tensor,
+ min: Optional[float] = None,
+ max: Optional[float] = None,
+ prob: float = 0.5,
+ reflect: float = 0.0,
+):
+ return RandomClampFunction.apply(x, min, max, prob, reflect)
+
+
+def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
+ """
+ A randomized way of casting a floating point value to half precision.
+ """
+ if x.dtype == torch.float16:
+ return x
+ x_abs = x.abs()
+ is_too_small = x_abs < min_abs
+ # for elements where is_too_small is true, random_val will contain +-min_abs with
+ # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
+ # for those elements].
+ random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
+ return torch.where(is_too_small, random_val, x).to(torch.float16)
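+
+# Expectation sketch (illustrative, not part of the original file): inputs far below
+# `min_abs` become 0 or +-min_abs with probabilities chosen so the mean is preserved.
+# >>> y = random_cast_to_half(torch.full((100_000,), 1.0e-06))
+# >>> y.dtype                         # torch.float16
+# >>> y.float().mean()                # close to 1e-06 in expectation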
+
+
+class RandomGradFunction(torch.autograd.Function):
+ """
+ Does nothing in forward pass; in backward pass, gets rid of very small grads using
+ randomized approach that preserves expectations (intended to reduce roundoff).
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
+ ctx.min_abs = min_abs
+ return x
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
+ if ans_grad.dtype == torch.float16:
+ return (
+ random_cast_to_half(
+ ans_grad.to(torch.float32), min_abs=ctx.min_abs
+ ),
+ None,
+ )
+ else:
+ return ans_grad, None
+
+
+class RandomGrad(torch.nn.Module):
+ """
+ Gets rid of very small gradients using an expectation-preserving method, intended to increase
+ accuracy of training when using amp (automatic mixed precision)
+ """
+
+ def __init__(self, min_abs: float = 5.0e-06):
+ super(RandomGrad, self).__init__()
+ self.min_abs = min_abs
+
+ def forward(self, x: Tensor):
+ if (
+ torch.jit.is_scripting()
+ or not self.training
+ or torch.jit.is_tracing()
+ ):
+ return x
+ else:
+ return RandomGradFunction.apply(x, self.min_abs)
+
+
+class SoftmaxFunction(torch.autograd.Function):
+ """
+ Tries to handle half-precision derivatives in a randomized way that should
+ be more accurate for training than the default behavior.
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor, dim: int):
+ ans = x.softmax(dim=dim)
+ # if x dtype is float16, x.softmax() returns a float32 because
+ # (presumably) that op does not support float16, and autocast
+ # is enabled.
+ if torch.is_autocast_enabled():
+ ans = ans.to(torch.float16)
+ ctx.save_for_backward(ans)
+ ctx.x_dtype = x.dtype
+ ctx.dim = dim
+ return ans
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor):
+ (ans,) = ctx.saved_tensors
+ with torch.cuda.amp.autocast(enabled=False):
+ ans_grad = ans_grad.to(torch.float32)
+ ans = ans.to(torch.float32)
+ x_grad = ans_grad * ans
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
+ return x_grad, None
+
+
+def softmax(x: Tensor, dim: int):
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x.softmax(dim)
+
+ return SoftmaxFunction.apply(x, dim)
+
+
+class MaxEigLimiterFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ coeffs: Tensor,
+ direction: Tensor,
+ channel_dim: int,
+ grad_scale: float,
+ ) -> Tensor:
+ ctx.channel_dim = channel_dim
+ ctx.grad_scale = grad_scale
+ ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad, *args):
+ with torch.enable_grad():
+ (x_orig, coeffs, new_direction) = ctx.saved_tensors
+ x_orig.requires_grad = True
+ num_channels = x_orig.shape[ctx.channel_dim]
+ x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
+ new_direction.requires_grad = False
+ x = x - x.mean(dim=0)
+ x_var = (x ** 2).mean()
+ x_residual = x - coeffs * new_direction
+ x_residual_var = (x_residual ** 2).mean()
+ # `variance_proportion` is the proportion of the variance accounted for
+ # by the top eigen-direction. This is to be minimized.
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
+ variance_proportion.backward()
+ x_orig_grad = x_orig.grad
+ x_extra_grad = (
+ x_orig.grad
+ * ctx.grad_scale
+ * x_grad.norm()
+ / (x_orig_grad.norm() + 1.0e-20)
+ )
+ return x_grad + x_extra_grad.detach(), None, None, None, None
+
+
+class BasicNorm(torch.nn.Module):
+ """
+ This is intended to be a simpler, and hopefully cheaper, replacement for
+ LayerNorm. The observation this is based on, is that Transformer-type
+ networks, especially with pre-norm, sometimes seem to set one of the
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
+ the LayerNorm because the output magnitude is then not strongly dependent
+ on the other (useful) features. Presumably the weight and bias of the
+ LayerNorm are required to allow it to do this.
+
+ So the idea is to introduce this large constant value as an explicit
+ parameter, that takes the role of the "eps" in LayerNorm, so the network
+ doesn't have to do this trick. We make the "eps" learnable.
+
+ Args:
+ num_channels: the number of channels, e.g. 512.
+ channel_dim: the axis/dimension corresponding to the channel,
+         interpreted as an offset from the input's ndim if negative.
+         This is NOT the num_channels; it should typically be one of
+ {-2, -1, 0, 1, 2, 3}.
+ eps: the initial "epsilon" that we add as ballast in:
+ scale = ((input_vec**2).mean() + epsilon)**-0.5
+ Note: our epsilon is actually large, but we keep the name
+ to indicate the connection with conventional LayerNorm.
+ learn_eps: if true, we learn epsilon; if false, we keep it
+ at the initial value.
+ eps_min: float
+ eps_max: float
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ channel_dim: int = -1, # CAUTION: see documentation.
+ eps: float = 0.25,
+ learn_eps: bool = True,
+ eps_min: float = -3.0,
+ eps_max: float = 3.0,
+ ) -> None:
+ super(BasicNorm, self).__init__()
+ self.num_channels = num_channels
+ self.channel_dim = channel_dim
+ if learn_eps:
+ self.eps = nn.Parameter(torch.tensor(eps).log().detach())
+ else:
+ self.register_buffer("eps", torch.tensor(eps).log().detach())
+ self.eps_min = eps_min
+ self.eps_max = eps_max
+
+ def forward(self, x: Tensor) -> Tensor:
+ assert x.shape[self.channel_dim] == self.num_channels
+ eps = self.eps
+ if self.training and random.random() < 0.25:
+            # With probability 0.25, in training mode, clamp eps between the min
+            # and max; this encourages the parameter to stay in the allowed range
+            # by making out-of-range values noisy, while still providing
+            # gradients to allow the parameter to get back into the allowed
+            # region if it happens to exit it.
+            eps = eps.clamp(min=self.eps_min, max=self.eps_max)
+ scales = (
+ torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
+ ) ** -0.5
+ return x * scales
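+
+# Illustrative check (not part of the original file): BasicNorm rescales each frame by
+# (mean(x**2) + eps.exp()) ** -0.5, so the per-frame RMS of the output stays below 1.
+# >>> norm = BasicNorm(num_channels=4, learn_eps=False)
+# >>> y = norm(10.0 * torch.randn(2, 3, 4))
+# >>> y.shape
+# torch.Size([2, 3, 4])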
+
+
+def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
+ """
+ Behaves like a constructor of a modified version of nn.Linear
+ that gives an easy way to set the default initial parameter scale.
+
+ Args:
+ Accepts the standard args and kwargs that nn.Linear accepts
+ e.g. in_features, out_features, bias=False.
+
+ initial_scale: you can override this if you want to increase
+ or decrease the initial magnitude of the module's output
+ (affects the initialization of weight_scale and bias_scale).
+ Another option, if you want to do something like this, is
+ to re-initialize the parameters.
+ """
+ ans = nn.Linear(*args, **kwargs)
+ with torch.no_grad():
+ ans.weight[:] *= initial_scale
+ if ans.bias is not None:
+ torch.nn.init.uniform_(
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
+ )
+ return ans
+
+
+def ScaledConv1d(
+ *args,
+ initial_scale: float = 1.0,
+ kernel_size: int = 3,
+ padding: str = "same",
+ **kwargs,
+) -> nn.Conv1d:
+ """
+ Behaves like a constructor of a modified version of nn.Conv1d
+ that gives an easy way to set the default initial parameter scale.
+
+ Args:
+        Accepts the standard args and kwargs that nn.Conv1d accepts
+        e.g. in_channels, out_channels, bias=False.
+
+ initial_scale: you can override this if you want to increase
+ or decrease the initial magnitude of the module's output
+ (affects the initialization of weight_scale and bias_scale).
+ Another option, if you want to do something like this, is
+ to re-initialize the parameters.
+ """
+ ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
+ with torch.no_grad():
+ ans.weight[:] *= initial_scale
+ if ans.bias is not None:
+ torch.nn.init.uniform_(
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
+ )
+ return ans
+
+
+def TransposeScaledConv1d(
+ *args,
+ initial_scale: float = 1.0,
+ kernel_size: int = 3,
+ padding: str = "same",
+ **kwargs,
+) -> nn.Sequential:
+ """
+ Transpose -> ScaledConv1d
+ """
+ return nn.Sequential(
+ Transpose(),
+ ScaledConv1d(
+ *args,
+ initial_scale=initial_scale,
+ kernel_size=kernel_size,
+ padding=padding,
+ **kwargs,
+ ),
+ )
+
+
+def ScaledConv1dTranspose(
+ *args,
+ initial_scale: float = 1.0,
+ kernel_size: int = 3,
+ padding: str = "same",
+ **kwargs,
+) -> nn.Sequential:
+ """
+    ScaledConv1d -> Transpose
+ """
+ return nn.Sequential(
+ ScaledConv1d(
+ *args,
+ initial_scale=initial_scale,
+ kernel_size=kernel_size,
+ padding=padding,
+ **kwargs,
+ ),
+ Transpose(),
+ )
+
+
+def TransposeConv1d(
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
+) -> nn.Sequential:
+ """
+ Transpose -> Conv1d
+ """
+ return nn.Sequential(
+ Transpose(),
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
+ )
+
+
+def Conv1dTranspose(
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
+) -> nn.Sequential:
+ """
+    Conv1d -> Transpose
+ """
+ return nn.Sequential(
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
+ Transpose(),
+ )
+
+
+class SRLinear(nn.Linear):
+ """https://arxiv.org/abs/2303.06296
+ Stabilizing Transformer Training by Preventing Attention Entropy Collapse
+ """
+
+ def __init__(self, in_features, out_features, bias=True, **kwargs):
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
+ self.register_buffer(
+ "u", nn.functional.normalize(torch.randn(in_features), dim=0)
+ )
+ with torch.no_grad():
+ sigma = self.get_sigma()
+ self.register_buffer("spectral_norm", sigma)
+ self.sigma = nn.Parameter(torch.ones(1))
+
+ def get_sigma(self):
+ with torch.no_grad():
+ u = self.u
+ v = self.weight.mv(u)
+ v = nn.functional.normalize(v, dim=0)
+ u = self.weight.T.mv(v)
+ u = nn.functional.normalize(u, dim=0)
+ self.u.data.copy_(u)
+ return torch.einsum("c,cd,d->", v, self.weight, u)
+
+ def get_weight(self):
+ sigma = self.get_sigma()
+ if self.training:
+ self.spectral_norm.data.copy_(sigma)
+ weight = (self.sigma / sigma) * self.weight
+ return weight
+
+ def forward(self, x):
+ return nn.functional.linear(x, self.get_weight(), self.bias)
+
+
+class SRConv1d(SRLinear):
+ def __init__(
+ self,
+ in_features,
+ out_features,
+ kernel_size,
+ stride: int = 1,
+ padding: str = "same",
+ bias: bool = True,
+ **kwargs,
+ ):
+ in_features = in_features * kernel_size
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+
+ def forward(self, x):
+ in_features = self.in_features // self.kernel_size
+ weight = self.get_weight().view(
+ self.out_features, in_features, self.kernel_size
+ )
+ return nn.functional.conv1d(
+ x, weight, bias=self.bias, stride=self.stride, padding=self.padding
+ )
+
+
+def TransposeSRConv1d(
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
+) -> nn.Sequential:
+ """
+ Transpose -> SRConv1d
+ """
+ return nn.Sequential(
+ Transpose(),
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
+ )
+
+
+def SRConv1dTranspose(
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
+) -> nn.Sequential:
+ """
+ SRConv1d -> Transpose
+ """
+ return nn.Sequential(
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
+ Transpose(),
+ )
+
+
+class ActivationBalancer(torch.nn.Module):
+ """
+ Modifies the backpropped derivatives of a function to try to encourage, for
+ each channel, that it is positive at least a proportion `threshold` of the
+ time. It does this by multiplying negative derivative values by up to
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
+ interpolated from 1 at the threshold to those extremal values when none
+ of the inputs are positive.
+
+ Args:
+ num_channels: the number of channels
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+ min_positive: the minimum, per channel, of the proportion of the time
+ that (x > 0), below which we start to modify the derivatives.
+ max_positive: the maximum, per channel, of the proportion of the time
+ that (x > 0), above which we start to modify the derivatives.
+ max_factor: the maximum factor by which we modify the derivatives for
+ either the sign constraint or the magnitude constraint;
+          e.g. with max_factor=0.02, the derivatives would be multiplied by
+ values in the range [0.98..1.02].
+ sign_gain_factor: determines the 'gain' with which we increase the
+ change in gradient once the constraints on min_positive and max_positive
+ are violated.
+ scale_gain_factor: determines the 'gain' with which we increase the
+ change in gradient once the constraints on min_abs and max_abs
+ are violated.
+ min_abs: the minimum average-absolute-value difference from the mean
+ value per channel, which we allow, before we start to modify
+ the derivatives to prevent this.
+ max_abs: the maximum average-absolute-value difference from the mean
+ value per channel, which we allow, before we start to modify
+ the derivatives to prevent this.
+ min_prob: determines the minimum probability with which we modify the
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
+ on each forward(). This is done randomly to prevent all layers
+ from doing it at the same time. Early in training we may use
+ higher probabilities than this; it will decay to this value.
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ channel_dim: int,
+ min_positive: float = 0.05,
+ max_positive: float = 0.95,
+ max_factor: float = 0.04,
+ sign_gain_factor: float = 0.01,
+ scale_gain_factor: float = 0.02,
+ min_abs: float = 0.2,
+ max_abs: float = 100.0,
+ min_prob: float = 0.1,
+ ):
+ super(ActivationBalancer, self).__init__()
+ self.num_channels = num_channels
+ self.channel_dim = channel_dim
+ self.min_positive = min_positive
+ self.max_positive = max_positive
+ self.max_factor = max_factor
+ self.min_abs = min_abs
+ self.max_abs = max_abs
+ self.min_prob = min_prob
+ self.sign_gain_factor = sign_gain_factor
+ self.scale_gain_factor = scale_gain_factor
+
+ # count measures how many times the forward() function has been called.
+ # We occasionally sync this to a tensor called `count`, that exists to
+ # make sure it is synced to disk when we load and save the model.
+ self.cpu_count = 0
+ self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
+
+ def forward(self, x: Tensor) -> Tensor:
+ if (
+ torch.jit.is_scripting()
+ or not x.requires_grad
+ or torch.jit.is_tracing()
+ ):
+ return _no_op(x)
+
+ count = self.cpu_count
+ self.cpu_count += 1
+
+ if random.random() < 0.01:
+ # Occasionally sync self.cpu_count with self.count.
+ # count affects the decay of 'prob'. don't do this on every iter,
+ # because syncing with the GPU is slow.
+ self.cpu_count = max(self.cpu_count, self.count.item())
+ self.count.fill_(self.cpu_count)
+
+ # the prob of doing some work exponentially decreases from 0.5 till it hits
+ # a floor at min_prob (==0.1, by default)
+ prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
+
+ if random.random() < prob:
+ sign_gain_factor = 0.5
+ if self.min_positive != 0.0 or self.max_positive != 1.0:
+ sign_factor = _compute_sign_factor(
+ x,
+ self.channel_dim,
+ self.min_positive,
+ self.max_positive,
+ gain_factor=self.sign_gain_factor / prob,
+ max_factor=self.max_factor,
+ )
+ else:
+ sign_factor = None
+
+ scale_factor = _compute_scale_factor(
+ x.detach(),
+ self.channel_dim,
+ min_abs=self.min_abs,
+ max_abs=self.max_abs,
+ gain_factor=self.scale_gain_factor / prob,
+ max_factor=self.max_factor,
+ )
+ return ActivationBalancerFunction.apply(
+ x,
+ scale_factor,
+ sign_factor,
+ self.channel_dim,
+ )
+ else:
+ return _no_op(x)
+
+
+def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
+ """
+ Returns x unmodified, but in backprop will put a penalty for the excess of
+ the absolute values of elements of x over the limit "limit". E.g. if
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
+
+    Caution: the value of this penalty will be affected by grad scaling used
+    in automatic mixed precision training. For the purpose we use this for,
+    that shouldn't really matter, and may even be helpful; we just use it
+    to disallow really implausible values of scores being given to softmax.
+ """
+ x_sign = x.sign()
+ over_limit = (x.abs() - limit) > 0
+ # The following is a memory efficient way to penalize the absolute values of
+ # x that's over the limit. (The memory efficiency comes when you think
+ # about which items torch needs to cache for the autograd, and which ones it
+ # can throw away). The numerical value of aux_loss as computed here will
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
+ # limit).relu().
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
+    # note: we don't do sum() here on aux_loss, but it's as if we had done
+ # sum() due to how with_loss() works.
+ x = with_loss(x, aux_loss)
+ # you must use x for something, or this will be ineffective.
+ return x
+
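+
+# Illustrative sketch (added for exposition; not part of the original module):
+# the forward value is unchanged, but backprop adds a gradient of magnitude
+# `penalty` to every element whose absolute value exceeds `limit`, pushing it
+# back towards the allowed range.
+def _example_penalize_abs_values_gt():
+    x = torch.tensor([0.5, -20.0, 3.0], requires_grad=True)
+    y = penalize_abs_values_gt(x, limit=10.0, penalty=1.0e-02)
+    (y * 0.0).sum().backward()  # zero out the direct path to isolate the penalty
+    return x.grad  # nonzero only for the element with |x| > limit
+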
+
+def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
+ if x.ndim == 2:
+ return x.diag()
+ else:
+ (batch, dim, dim) = x.shape
+ x = x.reshape(batch, dim * dim)
+ x = x[:, :: dim + 1]
+ assert x.shape == (batch, dim)
+ return x
+
+
+def _whitening_metric(x: Tensor, num_groups: int):
+ """
+ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
+ of the centered feature covariance are the same within each group's covariance matrix
+ and also between groups.
+ Args:
+ x: a Tensor of shape (*, num_channels)
+ num_groups: the number of groups of channels, a number >=1 that divides num_channels
+ Returns:
+ Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
+ greater than 1.0 otherwise.
+ """
+ assert x.dtype != torch.float16
+ x = x.reshape(-1, x.shape[-1])
+ (num_frames, num_channels) = x.shape
+ assert num_channels % num_groups == 0
+ channels_per_group = num_channels // num_groups
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
+ # x now has shape (num_groups, num_frames, channels_per_group)
+ # subtract the mean so we use the centered, not uncentered, covariance.
+ # My experience has been that when we "mess with the gradients" like this,
+    # it's better not to do anything that tries to move the mean around, because
+ # that can easily cause instability.
+ x = x - x.mean(dim=1, keepdim=True)
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
+ x_covar = torch.matmul(x.transpose(1, 2), x)
+ x_covar_mean_diag = _diag(x_covar).mean()
+ # the following expression is what we'd get if we took the matrix product
+ # of each covariance and measured the mean of its trace, i.e.
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
+ x_covarsq_mean_diag = (x_covar ** 2).sum() / (
+ num_groups * channels_per_group
+ )
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
+ return metric
+
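+
+# Illustrative sketch (added for exposition; not part of the original module):
+# for i.i.d. Gaussian features the centered covariance is close to a multiple
+# of the identity, so the metric is close to its floor of 1.0; mixing the
+# channels together makes the eigenvalues unequal and the metric larger.
+def _example_whitening_metric():
+    white = torch.randn(2000, 64)
+    correlated = white @ torch.randn(64, 64)  # channels become correlated
+    return (
+        _whitening_metric(white, num_groups=1),  # close to 1.0
+        _whitening_metric(correlated, num_groups=1),  # noticeably larger
+    )
+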
+
+class WhiteningPenaltyFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x: Tensor,
+ num_groups: int,
+ whitening_limit: float,
+ grad_scale: float,
+ ) -> Tensor:
+ ctx.save_for_backward(x)
+ ctx.num_groups = num_groups
+ ctx.whitening_limit = whitening_limit
+ ctx.grad_scale = grad_scale
+ return x
+
+ @staticmethod
+ def backward(ctx, x_grad: Tensor):
+ (x_orig,) = ctx.saved_tensors
+ with torch.enable_grad():
+ with torch.cuda.amp.autocast(enabled=False):
+ x_detached = x_orig.to(torch.float32).detach()
+ x_detached.requires_grad = True
+
+ metric = _whitening_metric(x_detached, ctx.num_groups)
+
+ if random.random() < 0.005 or __name__ == "__main__":
+ logging.info(
+ f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
+ f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
+ )
+
+ (metric - ctx.whitening_limit).relu().backward()
+ penalty_grad = x_detached.grad
+ scale = ctx.grad_scale * (
+ x_grad.to(torch.float32).norm()
+ / (penalty_grad.norm() + 1.0e-20)
+ )
+ penalty_grad = penalty_grad * scale
+ return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
+
+
+class Whiten(nn.Module):
+ def __init__(
+ self,
+ num_groups: int,
+ whitening_limit: float,
+ prob: Union[float, Tuple[float, float]],
+ grad_scale: float,
+ ):
+ """
+ Args:
+ num_groups: the number of groups to divide the channel dim into before
+ whitening. We will attempt to make the feature covariance
+ within each group, after mean subtraction, as "white" as possible,
+ while having the same trace across all groups.
+ whitening_limit: a value greater than 1.0, that dictates how much
+ freedom we have to violate the constraints. 1.0 would mean perfectly
+ white, with exactly the same trace across groups; larger values
+ give more freedom. E.g. 2.0.
+ prob: the probability with which we apply the gradient modification
+ (also affects the grad scale). May be supplied as a float,
+ or as a pair (min_prob, max_prob)
+
+ grad_scale: determines the scale on the gradient term from this object,
+ relative to the rest of the gradient on the attention weights.
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
+ """
+ super(Whiten, self).__init__()
+ assert num_groups >= 1
+ assert whitening_limit >= 1
+ assert grad_scale >= 0
+ self.num_groups = num_groups
+ self.whitening_limit = whitening_limit
+ if isinstance(prob, float):
+ assert 0 < prob <= 1
+ self.prob = prob
+ else:
+ (self.min_prob, self.max_prob) = prob
+ assert 0 < self.min_prob < self.max_prob <= 1
+ self.prob = self.max_prob
+
+ self.grad_scale = grad_scale
+
+ def forward(self, x: Tensor) -> Tensor:
+ """
+ In the forward pass, this function just returns the input unmodified.
+ In the backward pass, it will modify the gradients to ensure that the
+ distribution in each group has close to (lambda times I) as the covariance
+ after mean subtraction, with the same lambda across groups.
+ For whitening_limit > 1, there will be more freedom to violate this
+ constraint.
+
+ Args:
+ x: the input of shape (*, num_channels)
+
+ Returns:
+ x, unmodified. You should make sure
+ you use the returned value, or the graph will be freed
+ and nothing will happen in backprop.
+ """
+ if (
+ not x.requires_grad
+ or random.random() > self.prob
+ or self.grad_scale == 0
+ ):
+ return _no_op(x)
+ else:
+ if hasattr(self, "min_prob") and random.random() < 0.25:
+ # occasionally switch between min_prob and max_prob, based on whether
+ # we are above or below the threshold.
+ if (
+ _whitening_metric(x.to(torch.float32), self.num_groups)
+ > self.whitening_limit
+ ):
+ # there would be a change to the grad.
+ self.prob = self.max_prob
+ else:
+ self.prob = self.min_prob
+
+ return WhiteningPenaltyFunction.apply(
+ x, self.num_groups, self.whitening_limit, self.grad_scale
+ )
+
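+
+# Illustrative usage sketch (added for exposition; not part of the original
+# module): Whiten returns its input unchanged and, with probability `prob`,
+# adds a term to the gradient in backward() that pushes the per-group
+# covariance towards a multiple of the identity whenever the whitening
+# metric exceeds `whitening_limit`.
+def _example_whiten_usage():
+    m = Whiten(num_groups=1, whitening_limit=2.0, prob=1.0, grad_scale=0.02)
+    x = torch.randn(200, 64, requires_grad=True)
+    y = m(x)  # identical to x in value
+    y.backward(gradient=torch.randn_like(x))
+    return x.grad  # includes a whitening penalty term if the limit was exceeded
+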
+
+class WithLoss(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor, y: Tensor):
+ ctx.y_shape = y.shape
+ return x
+
+ @staticmethod
+ def backward(ctx, ans_grad: Tensor):
+ return ans_grad, torch.ones(
+ ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
+ )
+
+
+def with_loss(x, y):
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x
+ # returns x but adds y.sum() to the loss function.
+ return WithLoss.apply(x, y)
+
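+
+# Illustrative sketch (added for exposition; not part of the original module):
+# with_loss(x, y) returns x, but in backward() it behaves as if y.sum() had
+# been added to the loss, i.e. y receives a gradient of all ones.
+def _example_with_loss():
+    x = torch.randn(3, requires_grad=True)
+    aux = 2.0 * x  # an auxiliary quantity whose sum we want penalized
+    out = with_loss(x, aux)
+    out.sum().backward()
+    return x.grad  # all ones from out.sum(), plus 2.0 per element from aux.sum()
+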
+
+def _no_op(x: Tensor) -> Tensor:
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x
+ else:
+ # a no-op function that will have a node in the autograd graph,
+ # to avoid certain bugs relating to backward hooks
+ return x.chunk(1, dim=-1)[0]
+
+
+class Identity(torch.nn.Module):
+ def __init__(self):
+ super(Identity, self).__init__()
+
+ def forward(self, x):
+ return _no_op(x)
+
+
+class MaxEig(torch.nn.Module):
+ """
+    Modifies the backpropped derivatives of a function to discourage any given
+    direction in activation space from accounting for more than a specified
+    proportion of the covariance (e.g. 0.2).
+
+
+ Args:
+ num_channels: the number of channels
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
+ max_var_per_eig: the maximum proportion of the variance of the
+ features/channels, after mean subtraction, that can come from
+ any given eigenvalue.
+ min_prob: the minimum probability with which we apply this during any invocation
+ of forward(), assuming last time we applied the constraint it was
+ not active; supplied for speed.
+ scale: determines the scale with which we modify the gradients, relative
+ to the existing / unmodified gradients
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ channel_dim: int,
+ max_var_per_eig: float = 0.2,
+ min_prob: float = 0.01,
+ scale: float = 0.01,
+ ):
+ super(MaxEig, self).__init__()
+ self.num_channels = num_channels
+ self.channel_dim = channel_dim
+ self.scale = scale
+ assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
+ self.max_var_per_eig = max_var_per_eig
+
+ # we figure out the dominant direction using the power method: starting with
+ # a random vector, keep multiplying by the covariance and renormalizing.
+ with torch.no_grad():
+            # arbitrary; we would use randn(), but we want to leave the rest of
+            # the model's random parameters unchanged for comparison
+ direction = torch.arange(num_channels).to(torch.float)
+ direction = direction / direction.norm()
+ self.register_buffer("max_eig_direction", direction)
+
+ self.min_prob = min_prob
+        # cur_prob is the current probability with which we'll apply the MaxEig
+        # constraint. We'll regress this towards min_prob each time we try to
+        # apply it and it is not active.
+ self.cur_prob = 1.0
+
+ def forward(self, x: Tensor) -> Tensor:
+ if (
+ torch.jit.is_scripting()
+ or self.max_var_per_eig <= 0
+ or random.random() > self.cur_prob
+ or torch.jit.is_tracing()
+ ):
+ return _no_op(x)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ eps = 1.0e-20
+ orig_x = x
+ x = x.to(torch.float32)
+ with torch.no_grad():
+ x = x.transpose(self.channel_dim, -1).reshape(
+ -1, self.num_channels
+ )
+ x = x - x.mean(dim=0)
+ new_direction, coeffs = self._find_direction_coeffs(
+ x, self.max_eig_direction
+ )
+ x_var = (x ** 2).mean()
+ x_residual = x - coeffs * new_direction
+ x_residual_var = (x_residual ** 2).mean()
+
+ # `variance_proportion` is the proportion of the variance accounted for
+ # by the top eigen-direction.
+ variance_proportion = (x_var - x_residual_var) / (
+ x_var + 1.0e-20
+ )
+
+ # ensure new direction is nonzero even if x == 0, by including `direction`.
+ self._set_direction(
+ 0.1 * self.max_eig_direction + new_direction
+ )
+
+ if random.random() < 0.01 or __name__ == "__main__":
+ logging.info(
+ f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
+ )
+
+ if variance_proportion >= self.max_var_per_eig:
+                # The constraint is active.  Note: we should reach here quite
+                # rarely; this constraint should only become active near the
+                # beginning of training, if we are starting to diverge.
+ cur_prob = self.cur_prob
+ self.cur_prob = (
+ 1.0 # next time, do the update with probability 1.0.
+ )
+ return MaxEigLimiterFunction.apply(
+ orig_x, coeffs, new_direction, self.channel_dim, self.scale
+ )
+ else:
+ # let self.cur_prob exponentially approach self.min_prob, as
+ # long as the constraint is inactive.
+ self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
+ return orig_x
+
+ def _set_direction(self, direction: Tensor):
+ """
+ Sets self.max_eig_direction to a normalized version of `direction`
+ """
+ direction = direction.detach()
+ direction = direction / direction.norm()
+ direction_sum = direction.sum().item()
+ if direction_sum - direction_sum == 0: # no inf/nan
+ self.max_eig_direction[:] = direction
+ else:
+ logging.info(
+ f"Warning: sum of direction in MaxEig is {direction_sum}, "
+ "num_channels={self.num_channels}, channel_dim={self.channel_dim}"
+ )
+
+ def _find_direction_coeffs(
+ self, x: Tensor, prev_direction: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+ """
+ Figure out (an approximation to) the proportion of the variance of a set of
+ feature vectors that can be attributed to the top eigen-direction.
+ Args:
+ x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
+ prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
+ of the top eigen-direction, or a random direction if this is the first
+ iteration. Does not have to be normalized, but should be nonzero.
+
+ Returns: (cur_direction, coeffs), where:
+ cur_direction: a Tensor of shape (num_channels,) that is the current
+ estimate of the top eigen-direction.
+ coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
+ approximately minimizes, (x - coeffs * cur_direction).norm()
+ """
+ (num_frames, num_channels) = x.shape
+ assert num_channels > 1 and num_frames > 1
+ assert prev_direction.shape == (num_channels,)
+        # `coeffs` are the coefficients of `prev_direction` in x; they actually
+        # represent the coeffs only up to a constant positive factor.
+ coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
+ cur_direction = (x * coeffs).sum(dim=0) / (
+ (coeffs ** 2).sum() + 1.0e-20
+ )
+ return cur_direction, coeffs
+
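+
+# Illustrative usage sketch (added for exposition; not part of the original
+# module): MaxEig is an identity in the forward pass; it keeps a running
+# power-method estimate of the dominant eigen-direction of the feature
+# covariance and only modifies gradients when that direction explains more
+# than max_var_per_eig of the variance.
+def _example_max_eig_usage():
+    m = MaxEig(num_channels=64, channel_dim=-1, max_var_per_eig=0.2)
+    x = torch.randn(50, 64, requires_grad=True)
+    y = m(x)  # same values as x
+    y.sum().backward()  # x.grad is modified only if the constraint is active
+    return x.grad
+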
+
+class DoubleSwishFunction(torch.autograd.Function):
+ """
+ double_swish(x) = x * torch.sigmoid(x-1)
+ This is a definition, originally motivated by its close numerical
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
+
+ Memory-efficient derivative computation:
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
+ Now, s'(x) = s(x) * (1-s(x)).
+ double_swish'(x) = x * s'(x) + s(x).
+ = x * s(x) * (1-s(x)) + s(x).
+ = double_swish(x) * (1-s(x)) + s(x)
+ ... so we just need to remember s(x) but not x itself.
+ """
+
+ @staticmethod
+ def forward(ctx, x: Tensor) -> Tensor:
+ requires_grad = x.requires_grad
+ x_dtype = x.dtype
+ if x.dtype == torch.float16:
+ x = x.to(torch.float32)
+
+ s = torch.sigmoid(x - 1.0)
+ y = x * s
+
+ if requires_grad:
+ deriv = y * (1 - s) + s
+ # notes on derivative of x * sigmoid(x - 1):
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
+            # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
+ # floors), should be expectation-preserving.
+ floor = -0.043637
+ ceil = 1.2
+ d_scaled = (deriv - floor) * (
+ 255.0 / (ceil - floor)
+ ) + torch.rand_like(deriv)
+ if __name__ == "__main__":
+ # for self-testing only.
+ assert d_scaled.min() >= 0.0
+ assert d_scaled.max() < 256.0
+ d_int = d_scaled.to(torch.uint8)
+ ctx.save_for_backward(d_int)
+        if x_dtype == torch.float16 or torch.is_autocast_enabled():
+ y = y.to(torch.float16)
+ return y
+
+ @staticmethod
+ def backward(ctx, y_grad: Tensor) -> Tensor:
+ (d,) = ctx.saved_tensors
+ # the same constants as used in forward pass.
+ floor = -0.043637
+ ceil = 1.2
+ d = d * ((ceil - floor) / 255.0) + floor
+ return y_grad * d
+
+
+class DoubleSwish(torch.nn.Module):
+ def forward(self, x: Tensor) -> Tensor:
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
+ that we approximate closely with x * sigmoid(x-1).
+ """
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
+ return x * torch.sigmoid(x - 1.0)
+ return DoubleSwishFunction.apply(x)
+
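+
+# Illustrative sketch (added for exposition; not part of the original module):
+# numerically, x * sigmoid(x - 1) tracks swish(swish(x)) closely over a
+# typical activation range, which is the motivation given in the docstring
+# of DoubleSwishFunction.
+def _example_double_swish_vs_swish_swish():
+    x = torch.linspace(-6.0, 6.0, steps=100)
+
+    def swish(t: Tensor) -> Tensor:
+        return t * torch.sigmoid(t)
+
+    gap = (DoubleSwish()(x) - swish(swish(x))).abs().max()
+    return gap  # small relative to the function values, but not exactly zero
+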
+
+def BalancedDoubleSwish(
+ d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
+) -> nn.Sequential:
+ """
+ ActivationBalancer -> DoubleSwish
+ """
+ balancer = ActivationBalancer(
+ d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
+ )
+ return nn.Sequential(
+ balancer,
+ DoubleSwish(),
+ )
+
+
+def _test_max_eig():
+ for proportion in [0.1, 0.5, 10.0]:
+ logging.info(f"proportion = {proportion}")
+ x = torch.randn(100, 128)
+ direction = torch.randn(128)
+ coeffs = torch.randn(100, 1)
+ x += proportion * direction * coeffs
+
+ x.requires_grad = True
+
+ num_channels = 128
+        m = MaxEig(
+            num_channels,
+            1,  # channel_dim
+            0.5,  # max_var_per_eig
+            scale=0.1,
+        )
+
+ for _ in range(4):
+ y = m(x)
+
+ y_grad = torch.randn_like(x)
+ y.backward(gradient=y_grad)
+
+ if proportion < 0.2:
+ assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
+ elif proportion > 1.0:
+ assert not torch.allclose(x.grad, y_grad)
+
+
+def _test_whiten():
+ for proportion in [0.1, 0.5, 10.0]:
+ logging.info(f"_test_whiten(): proportion = {proportion}")
+ x = torch.randn(100, 128)
+ direction = torch.randn(128)
+ coeffs = torch.randn(100, 1)
+ x += proportion * direction * coeffs
+
+ x.requires_grad = True
+
+ num_channels = 128
+        m = Whiten(
+            1,  # num_groups
+            5.0,  # whitening_limit
+            prob=1.0,
+            grad_scale=0.1,
+        )
+
+ for _ in range(4):
+ y = m(x)
+
+ y_grad = torch.randn_like(x)
+ y.backward(gradient=y_grad)
+
+ if proportion < 0.2:
+ assert torch.allclose(x.grad, y_grad)
+ elif proportion > 1.0:
+ assert not torch.allclose(x.grad, y_grad)
+
+
+def _test_activation_balancer_sign():
+ probs = torch.arange(0, 1, 0.01)
+ N = 1000
+ x = 1.0 * (
+ (2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
+ )
+ x = x.detach()
+ x.requires_grad = True
+ m = ActivationBalancer(
+ probs.numel(),
+ channel_dim=0,
+ min_positive=0.05,
+ max_positive=0.95,
+ max_factor=0.2,
+ min_abs=0.0,
+ )
+
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
+
+ y = m(x)
+ y.backward(gradient=y_grad)
+ print("_test_activation_balancer_sign: x = ", x)
+ print("_test_activation_balancer_sign: y grad = ", y_grad)
+ print("_test_activation_balancer_sign: x grad = ", x.grad)
+
+
+def _test_activation_balancer_magnitude():
+ magnitudes = torch.arange(0, 1, 0.01)
+ N = 1000
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
+ -1
+ )
+ x = x.detach()
+ x.requires_grad = True
+ m = ActivationBalancer(
+ magnitudes.numel(),
+ channel_dim=0,
+ min_positive=0.0,
+ max_positive=1.0,
+ max_factor=0.2,
+ min_abs=0.2,
+ max_abs=0.8,
+ min_prob=1.0,
+ )
+
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
+
+ y = m(x)
+ y.backward(gradient=y_grad)
+ print("_test_activation_balancer_magnitude: x = ", x)
+ print("_test_activation_balancer_magnitude: y grad = ", y_grad)
+ print("_test_activation_balancer_magnitude: x grad = ", x.grad)
+
+
+def _test_basic_norm():
+ num_channels = 128
+ m = BasicNorm(num_channels=num_channels, channel_dim=1)
+
+ x = torch.randn(500, num_channels)
+
+ y = m(x)
+
+ assert y.shape == x.shape
+ x_rms = (x ** 2).mean().sqrt()
+ y_rms = (y ** 2).mean().sqrt()
+ print("x rms = ", x_rms)
+ print("y rms = ", y_rms)
+ assert y_rms < x_rms
+ assert y_rms > 0.5 * x_rms
+
+
+def _test_double_swish_deriv():
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ m = DoubleSwish()
+
+ tol = (1.2 - (-0.043637)) / 255.0
+ torch.autograd.gradcheck(m, x, atol=tol)
+
+ # for self-test.
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+ x.requires_grad = True
+ y = m(x)
+
+
+def _test_softmax():
+ a = torch.randn(2, 10, dtype=torch.float64)
+ b = a.clone()
+ a.requires_grad = True
+ b.requires_grad = True
+ a.softmax(dim=1)[:, 0].sum().backward()
+ print("a grad = ", a.grad)
+ softmax(b, dim=1)[:, 0].sum().backward()
+ print("b grad = ", b.grad)
+ assert torch.allclose(a.grad, b.grad)
+
+
+if __name__ == "__main__":
+ logging.getLogger().setLevel(logging.INFO)
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ _test_softmax()
+ _test_whiten()
+ _test_max_eig()
+ _test_activation_balancer_sign()
+ _test_activation_balancer_magnitude()
+ _test_basic_norm()
+ _test_double_swish_deriv()
diff --git a/modules/transformer.py b/modules/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea8826b193c5053cb8ae74312f65ac95fe440350
--- /dev/null
+++ b/modules/transformer.py
@@ -0,0 +1,683 @@
+import copy
+import numbers
+from functools import partial
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+from .activation import MultiheadAttention
+from .scaling import ActivationBalancer, BalancedDoubleSwish
+from .scaling import BasicNorm as _BasicNorm
+
+_shape_t = Union[int, List[int], torch.Size]
+
+
+class LayerNorm(nn.Module):
+ __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
+ normalized_shape: Tuple[int, ...]
+ eps: float
+ elementwise_affine: bool
+
+ def __init__(
+ self,
+ normalized_shape: _shape_t,
+ eps: float = 1e-5,
+ elementwise_affine: bool = True,
+ device=None,
+ dtype=None,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super(LayerNorm, self).__init__()
+ if isinstance(normalized_shape, numbers.Integral):
+ # mypy error: incompatible types in assignment
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
+ self.eps = eps
+ self.elementwise_affine = elementwise_affine
+ if self.elementwise_affine:
+ self.weight = nn.Parameter(
+ torch.empty(self.normalized_shape, **factory_kwargs)
+ )
+ self.bias = nn.Parameter(
+ torch.empty(self.normalized_shape, **factory_kwargs)
+ )
+ else:
+ self.register_parameter("weight", None)
+ self.register_parameter("bias", None)
+
+ self.reset_parameters()
+
+ def reset_parameters(self) -> None:
+ if self.elementwise_affine:
+ nn.init.ones_(self.weight)
+ nn.init.zeros_(self.bias)
+
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
+ if isinstance(input, tuple):
+ input, embedding = input
+ return (
+ F.layer_norm(
+ input,
+ self.normalized_shape,
+ self.weight,
+ self.bias,
+ self.eps,
+ ),
+ embedding,
+ )
+
+ assert embedding is None
+ return F.layer_norm(
+ input, self.normalized_shape, self.weight, self.bias, self.eps
+ )
+
+ def extra_repr(self) -> str:
+ return (
+ "{normalized_shape}, eps={eps}, "
+ "elementwise_affine={elementwise_affine}".format(**self.__dict__)
+ )
+
+
+class AdaptiveLayerNorm(nn.Module):
+ r"""Adaptive Layer Normalization"""
+
+ def __init__(self, d_model, norm) -> None:
+ super(AdaptiveLayerNorm, self).__init__()
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
+ self.norm = norm
+ self.d_model = d_model
+ self.eps = self.norm.eps
+
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
+ if isinstance(input, tuple):
+ input, embedding = input
+ weight, bias = torch.split(
+ self.project_layer(embedding),
+ split_size_or_sections=self.d_model,
+ dim=-1,
+ )
+ return (weight * self.norm(input) + bias, embedding)
+
+ weight, bias = torch.split(
+ self.project_layer(embedding),
+ split_size_or_sections=self.d_model,
+ dim=-1,
+ )
+ return weight * self.norm(input) + bias
+
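+
+# Illustrative usage sketch (added for exposition; not part of the original
+# module): AdaptiveLayerNorm projects a conditioning embedding into per-channel
+# (weight, bias) terms that are applied on top of the wrapped norm, so the same
+# layer can behave differently depending on, e.g., the stage embedding.
+def _example_adaptive_layer_norm():
+    d_model = 16
+    ada_norm = AdaptiveLayerNorm(d_model, norm=LayerNorm(d_model))
+    x = torch.randn(2, 5, d_model)  # (batch, time, channels); arbitrary shapes
+    stage_embedding = torch.randn(2, 5, d_model)
+    return ada_norm(x, stage_embedding)  # same shape as x
+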
+
+class BasicNorm(_BasicNorm):
+ def __init__(
+ self,
+ d_model: int,
+ eps: float = 1e-5,
+ device=None,
+ dtype=None,
+ ):
+ super(BasicNorm, self).__init__(d_model, eps=eps)
+
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
+ if isinstance(input, tuple):
+ input, embedding = input
+ return (
+ super(BasicNorm, self).forward(input),
+ embedding,
+ )
+
+ assert embedding is None
+ return super(BasicNorm, self).forward(input)
+
+
+class BalancedBasicNorm(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ eps: float = 1e-5,
+ device=None,
+ dtype=None,
+ ):
+ super(BalancedBasicNorm, self).__init__()
+ self.balancer = ActivationBalancer(
+ d_model,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ max_abs=6.0,
+ )
+ self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
+
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
+ if isinstance(input, tuple):
+ input, embedding = input
+ return self.norm((self.balancer(input), embedding))
+
+ assert embedding is None
+ return self.norm(self.balancer(input))
+
+
+class IdentityNorm(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ eps: float = 1e-5,
+ device=None,
+ dtype=None,
+ ) -> None:
+ super(IdentityNorm, self).__init__()
+
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
+ if isinstance(input, tuple):
+ return input
+
+ assert embedding is None
+ return input
+
+
+class TransformerEncoderLayer(nn.Module):
+ __constants__ = ["batch_first", "norm_first"]
+
+ def __init__(
+ self,
+ d_model: int,
+ nhead: int,
+ dim_feedforward: int = 2048,
+ dropout: float = 0.1,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ batch_first: bool = False,
+ norm_first: bool = False,
+ device=None,
+ dtype=None,
+ linear1_self_attention_cls: nn.Module = nn.Linear,
+ linear2_self_attention_cls: nn.Module = nn.Linear,
+ linear1_feedforward_cls: nn.Module = nn.Linear,
+ linear2_feedforward_cls: nn.Module = nn.Linear,
+ layer_norm_cls: nn.Module = LayerNorm,
+ layer_norm_eps: float = 1e-5,
+ adaptive_layer_norm=False,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super(TransformerEncoderLayer, self).__init__()
+ self.self_attn = MultiheadAttention(
+ d_model,
+ nhead,
+ dropout=dropout,
+ batch_first=batch_first,
+ linear1_cls=linear1_self_attention_cls,
+ linear2_cls=linear2_self_attention_cls,
+ **factory_kwargs,
+ )
+
+ # Implementation of Feedforward model
+ self.linear1 = linear1_feedforward_cls(
+ d_model, dim_feedforward, **factory_kwargs
+ )
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = linear2_feedforward_cls(
+ dim_feedforward, d_model, **factory_kwargs
+ )
+
+ self.norm_first = norm_first
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ # Legacy string support for activation function.
+ if isinstance(activation, str):
+ activation = _get_activation_fn(activation)
+ elif isinstance(activation, partial):
+ activation = activation(d_model)
+ elif activation == BalancedDoubleSwish:
+ activation = BalancedDoubleSwish(d_model)
+
+ # # We can't test self.activation in forward() in TorchScript,
+ # # so stash some information about it instead.
+ # if activation is F.relu or isinstance(activation, torch.nn.ReLU):
+ # self.activation_relu_or_gelu = 1
+ # elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
+ # self.activation_relu_or_gelu = 2
+ # else:
+ # self.activation_relu_or_gelu = 0
+ self.activation = activation
+
+ norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
+ if layer_norm_cls == IdentityNorm:
+ norm2 = BalancedBasicNorm(
+ d_model, eps=layer_norm_eps, **factory_kwargs
+ )
+ else:
+ norm2 = layer_norm_cls(
+ d_model, eps=layer_norm_eps, **factory_kwargs
+ )
+
+ if adaptive_layer_norm:
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
+ else:
+ self.norm1 = norm1
+ self.norm2 = norm2
+
+ def __setstate__(self, state):
+ super(TransformerEncoderLayer, self).__setstate__(state)
+ if not hasattr(self, "activation"):
+ self.activation = F.relu
+
+ def forward(
+ self,
+ src: Tensor,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Pass the input through the encoder layer.
+
+ Args:
+ src: the sequence to the encoder layer (required).
+ src_mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+
+ Shape:
+ see the docs in Transformer class.
+ """
+ x, stage_embedding = src, None
+ is_src_tuple = False
+ if isinstance(src, tuple):
+ x, stage_embedding = src
+ is_src_tuple = True
+
+ if src_key_padding_mask is not None:
+ _skpm_dtype = src_key_padding_mask.dtype
+ if _skpm_dtype != torch.bool and not torch.is_floating_point(
+ src_key_padding_mask
+ ):
+ raise AssertionError(
+ "only bool and floating types of key_padding_mask are supported"
+ )
+
+ if self.norm_first:
+ x = x + self._sa_block(
+ self.norm1(x, stage_embedding),
+ src_mask,
+ src_key_padding_mask,
+ )
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
+ else:
+ x = self.norm1(
+ x + self._sa_block(x, src_mask, src_key_padding_mask),
+ stage_embedding,
+ )
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
+
+ if is_src_tuple:
+ return (x, stage_embedding)
+ return x
+
+ def infer(
+ self,
+ src: Tensor,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ past_kv: Optional[Tensor] = None,
+ use_cache: bool = False,
+ ):
+ x, stage_embedding = src, None
+ is_src_tuple = False
+ if isinstance(src, tuple):
+ x, stage_embedding = src
+ is_src_tuple = True
+
+ if src_key_padding_mask is not None:
+ _skpm_dtype = src_key_padding_mask.dtype
+ if _skpm_dtype != torch.bool and not torch.is_floating_point(
+ src_key_padding_mask
+ ):
+ raise AssertionError(
+ "only bool and floating types of key_padding_mask are supported"
+ )
+
+ if self.norm_first:
+ x_attn_out, kv = self.self_attn.infer(
+ self.norm1(x, stage_embedding),
+ attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask,
+ need_weights=False,
+ past_kv=past_kv,
+ use_cache=use_cache,
+ )
+ x = x + x_attn_out
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
+
+ if is_src_tuple:
+ return (x, stage_embedding)
+ return (x, kv)
+
+ # self-attention block
+ def _sa_block(
+ self,
+ x: Tensor,
+ attn_mask: Optional[Tensor],
+ key_padding_mask: Optional[Tensor],
+ ) -> Tensor:
+ x = self.self_attn(
+ x,
+ x,
+ x,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask,
+ need_weights=False,
+ )[0]
+ return self.dropout1(x)
+
+ # feed forward block
+ def _ff_block(self, x: Tensor) -> Tensor:
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+ return self.dropout2(x)
+
+
+class TransformerEncoder(nn.Module):
+ r"""TransformerEncoder is a stack of N encoder layers. Users can build the
+ BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
+
+ Args:
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
+ num_layers: the number of sub-encoder-layers in the encoder (required).
+ norm: the layer normalization component (optional).
+
+ Examples::
+ >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
+ >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
+ >>> src = torch.rand(10, 32, 512)
+ >>> out = transformer_encoder(src)
+ """
+ __constants__ = ["norm"]
+
+ def __init__(self, encoder_layer, num_layers, norm=None):
+ super(TransformerEncoder, self).__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+ self.norm = norm
+
+ def forward(
+ self,
+ src: Tensor,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ return_layer_states: bool = False,
+ ) -> Tensor:
+ r"""Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder (required).
+ mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+ return_layer_states: return layers' state (optional).
+
+ Shape:
+ see the docs in Transformer class.
+ """
+ if return_layer_states:
+ layer_states = [] # layers' output
+ output = src
+ for mod in self.layers:
+ output = mod(
+ output,
+ src_mask=mask,
+ src_key_padding_mask=src_key_padding_mask,
+ )
+ layer_states.append(output[0])
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return layer_states, output
+
+ output = src
+ for mod in self.layers:
+ output = mod(
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
+ )
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
+
+ def infer(
+ self,
+ src: Tensor,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ return_layer_states: bool = False,
+ past_kv: Optional[Tensor] = None,
+ use_cache: bool = False,
+ ):
+ if past_kv is None:
+ past_length = 0
+ past_kv = tuple([None] * self.num_layers)
+ else:
+ past_length = past_kv[0][0].size(-2)
+ new_kv = () if use_cache else None
+ output = src
+ for mod, past_layer_kv in zip(self.layers, past_kv):
+ output, kv = mod.infer(
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past_kv=past_layer_kv, use_cache=use_cache
+ )
+ if use_cache:
+ new_kv = new_kv + (kv,)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output, new_kv
+
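+
+# Illustrative sketch (added for exposition; not part of the original module):
+# infer() threads a per-layer key/value cache through the stack so that
+# incremental decoding only runs attention for the newly appended positions.
+# Note that TransformerEncoderLayer.infer only implements the norm_first path,
+# so the layers here are assumed to have been built with norm_first=True.
+def _example_incremental_decoding(encoder: TransformerEncoder, steps: int = 3):
+    d_model = encoder.layers[0].linear1.in_features
+    x = torch.randn(1, 1, d_model)  # embedding of the first position (hypothetical)
+    past_kv = None
+    for _ in range(steps):
+        out, past_kv = encoder.infer(x, past_kv=past_kv, use_cache=True)
+        x = torch.randn(1, 1, d_model)  # embedding of the next position (hypothetical)
+    return out, past_kv
+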
+
+class TransformerDecoderLayer(nn.Module):
+ __constants__ = ["batch_first", "norm_first"]
+
+ def __init__(
+ self,
+ d_model: int,
+ nhead: int,
+ dim_feedforward: int = 2048,
+ dropout: float = 0.1,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ linear1_self_attention_cls: nn.Module = nn.Linear,
+ linear2_self_attention_cls: nn.Module = nn.Linear,
+ linear1_feedforward_cls: nn.Module = nn.Linear,
+ linear2_feedforward_cls: nn.Module = nn.Linear,
+ batch_first: bool = False,
+ norm_first: bool = False,
+ device=None,
+ dtype=None,
+ layer_norm_cls: nn.Module = LayerNorm,
+ layer_norm_eps: float = 1e-5,
+ adaptive_layer_norm=False,
+ ) -> None:
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super(TransformerDecoderLayer, self).__init__()
+ self.self_attn = MultiheadAttention(
+ d_model,
+ nhead,
+ dropout=dropout,
+ batch_first=batch_first,
+ linear1_cls=linear1_self_attention_cls,
+ linear2_cls=linear2_self_attention_cls,
+ **factory_kwargs,
+ )
+ self.multihead_attn = MultiheadAttention(
+ d_model,
+ nhead,
+ dropout=dropout,
+ batch_first=batch_first,
+ linear1_cls=linear1_self_attention_cls,
+ linear2_cls=linear2_self_attention_cls,
+ **factory_kwargs,
+ )
+ # Implementation of Feedforward model
+ self.linear1 = linear1_feedforward_cls(
+ d_model, dim_feedforward, **factory_kwargs
+ )
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = linear2_feedforward_cls(
+ dim_feedforward, d_model, **factory_kwargs
+ )
+
+ self.norm_first = norm_first
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ # Legacy string support for activation function.
+ if isinstance(activation, str):
+ self.activation = _get_activation_fn(activation)
+ elif isinstance(activation, partial):
+ self.activation = activation(d_model)
+ elif activation == BalancedDoubleSwish:
+ self.activation = BalancedDoubleSwish(d_model)
+ else:
+ self.activation = activation
+
+ if adaptive_layer_norm:
+ norm1 = layer_norm_cls(
+ d_model, eps=layer_norm_eps, **factory_kwargs
+ )
+ norm2 = layer_norm_cls(
+ d_model, eps=layer_norm_eps, **factory_kwargs
+ )
+ norm3 = layer_norm_cls(
+ d_model, eps=layer_norm_eps, **factory_kwargs
+ )
+
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
+ self.norm3 = AdaptiveLayerNorm(d_model, norm3)
+ else:
+ self.norm1 = layer_norm_cls(
+ d_model, eps=layer_norm_eps, **factory_kwargs
+ )
+ self.norm2 = layer_norm_cls(
+ d_model, eps=layer_norm_eps, **factory_kwargs
+ )
+ if layer_norm_cls == IdentityNorm:
+ self.norm3 = BalancedBasicNorm(
+ d_model, eps=layer_norm_eps, **factory_kwargs
+ )
+ else:
+ self.norm3 = layer_norm_cls(
+ d_model, eps=layer_norm_eps, **factory_kwargs
+ )
+
+ def forward(
+ self,
+ tgt: Tensor,
+ memory: Tensor,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ r"""Pass the inputs (and mask) through the decoder layer.
+
+ Args:
+ tgt: the sequence to the decoder layer (required).
+ memory: the sequence from the last layer of the encoder (required).
+ tgt_mask: the mask for the tgt sequence (optional).
+ memory_mask: the mask for the memory sequence (optional).
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
+
+ Shape:
+ see the docs in Transformer class.
+ """
+ tgt_is_tuple = False
+ if isinstance(tgt, tuple):
+ x, stage_embedding = tgt
+ tgt_is_tuple = True
+ else:
+ x, stage_embedding = tgt, None
+
+ if self.norm_first:
+ x = x + self._sa_block(
+ self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
+ )
+ x = x + self._mha_block(
+ self.norm2(x, stage_embedding),
+ memory,
+ memory_mask,
+ memory_key_padding_mask,
+ )
+ x = x + self._ff_block(self.norm3(x, stage_embedding))
+ else:
+ x = self.norm1(
+ x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
+ stage_embedding,
+ )
+ x = self.norm2(
+ x
+ + self._mha_block(
+ x, memory, memory_mask, memory_key_padding_mask
+ ),
+ stage_embedding,
+ )
+ x = self.norm3(x + self._ff_block(x), stage_embedding)
+
+ if tgt_is_tuple:
+ return (x, stage_embedding)
+ return x
+
+ # self-attention block
+ def _sa_block(
+ self,
+ x: Tensor,
+ attn_mask: Optional[Tensor],
+ key_padding_mask: Optional[Tensor],
+ ) -> Tensor:
+ x = self.self_attn(
+ x,
+ x,
+ x,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask,
+ need_weights=False,
+ )[0]
+ return self.dropout1(x)
+
+ # multihead attention block
+ def _mha_block(
+ self,
+ x: Tensor,
+ mem: Tensor,
+ attn_mask: Optional[Tensor],
+ key_padding_mask: Optional[Tensor],
+ ) -> Tensor:
+ x = self.multihead_attn(
+ x,
+ mem,
+ mem,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask,
+ need_weights=False,
+ )[0]
+ return self.dropout2(x)
+
+ # feed forward block
+ def _ff_block(self, x: Tensor) -> Tensor:
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+ return self.dropout3(x)
+
+
+def _get_clones(module, N):
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
+ if activation == "relu":
+ return F.relu
+ elif activation == "gelu":
+ return F.gelu
+
+ raise RuntimeError(
+ "activation should be relu/gelu, not {}".format(activation)
+ )
diff --git a/presets/acou_1.npz b/presets/acou_1.npz
new file mode 100644
index 0000000000000000000000000000000000000000..f6c51bd1c0a5dc6eebcf3c63c17c05d1d612f6ff
--- /dev/null
+++ b/presets/acou_1.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:470ce66fc24a2d14e162343381f7d93ef0a3af51edf5fd37240c21f492b4e769
+size 15650
diff --git a/presets/acou_2.npz b/presets/acou_2.npz
new file mode 100644
index 0000000000000000000000000000000000000000..1e055e2639e010f57e74d11cd37d134f8d5ee05e
--- /dev/null
+++ b/presets/acou_2.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec1c5328751cadeed5356d4264759799ad96d33ea8dd4f8a3d0a80dd8ddb0e74
+size 15426
diff --git a/presets/acou_3.npz b/presets/acou_3.npz
new file mode 100644
index 0000000000000000000000000000000000000000..1eb6978a203b4df5124bf745c1fde591d1864ce7
--- /dev/null
+++ b/presets/acou_3.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:03f241b094a32b3f542e74374183c6d15e8b70ae73ceeafb11bfd4ee6b8b4a3a
+size 15410
diff --git a/presets/acou_4.npz b/presets/acou_4.npz
new file mode 100644
index 0000000000000000000000000000000000000000..c0e623ffed42dd0fd089e928a79eeb25721ba6d3
--- /dev/null
+++ b/presets/acou_4.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:52b96f32863f13f84cf7ac4a27d2bc95cea70c350a037f4d1890b20b8da9501e
+size 15506
diff --git a/presets/alan.npz b/presets/alan.npz
new file mode 100644
index 0000000000000000000000000000000000000000..156afabccea6548d1a1e65e4b5bb95b6b85e493f
--- /dev/null
+++ b/presets/alan.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28838c3f0b2f9f315b34e9b940f30641306f0cadc5c527857cd1cc408547ed1c
+size 50002
diff --git a/presets/amused.npz b/presets/amused.npz
new file mode 100644
index 0000000000000000000000000000000000000000..3d9b45ee3d7e557bb754d6564312479b92acf5fc
--- /dev/null
+++ b/presets/amused.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df3e882f3a62805b9aaf300d81822cd4eddeafee480503b7b78e32be2085fb11
+size 20882
diff --git a/presets/anger.npz b/presets/anger.npz
new file mode 100644
index 0000000000000000000000000000000000000000..26477928feb6c7da2b0bb3b29ba3122adf2a000e
--- /dev/null
+++ b/presets/anger.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:959cec6dc0b30219db0d70cdd165fe00bbdc098165cf9d67ccdd1ecf7a5da5be
+size 22090
diff --git a/presets/babara.npz b/presets/babara.npz
new file mode 100644
index 0000000000000000000000000000000000000000..9a484d8b9a6ad6a907e426eccda7b0a4e6e8884e
--- /dev/null
+++ b/presets/babara.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8106b2a98c3f70587f23ab46ed5bf73b1c9a770481c3620ab140bd3256010376
+size 11526
diff --git a/presets/bronya_1.npz b/presets/bronya_1.npz
new file mode 100644
index 0000000000000000000000000000000000000000..361939a93a9fd2c00c775bb761f4a8afd9d226a9
--- /dev/null
+++ b/presets/bronya_1.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02eaada2c3d58866c813887ed9f871587ef5a7e976abc23382ce46a17b208001
+size 18106
diff --git a/presets/cafe.npz b/presets/cafe.npz
new file mode 100644
index 0000000000000000000000000000000000000000..70b20f6e09decc37226a4af477a0757110c04224
--- /dev/null
+++ b/presets/cafe.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d78d96f5829da8f69c327ff25958da5b451305fdc9c308f7e67f13cf8d640fea
+size 22442
diff --git a/presets/dingzhen.npz b/presets/dingzhen.npz
new file mode 100644
index 0000000000000000000000000000000000000000..4da9178da67661edeb4868d9e251b016db846511
--- /dev/null
+++ b/presets/dingzhen.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4d19167c65eefef5e42dfaa1919ff5149ca0a93cb052396a47d1f42f9865f5f8
+size 18154
diff --git a/presets/dingzhen_1.npz b/presets/dingzhen_1.npz
new file mode 100644
index 0000000000000000000000000000000000000000..4da9178da67661edeb4868d9e251b016db846511
--- /dev/null
+++ b/presets/dingzhen_1.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4d19167c65eefef5e42dfaa1919ff5149ca0a93cb052396a47d1f42f9865f5f8
+size 18154
diff --git a/presets/disgust.npz b/presets/disgust.npz
new file mode 100644
index 0000000000000000000000000000000000000000..fa775736b826d61213653a808855eaf8d263c61d
--- /dev/null
+++ b/presets/disgust.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4443f0a395072700f2ec6101dbf2ad9d28968aa3e5809e384ea131832f894d7f
+size 39386
diff --git a/presets/emo_amused.npz b/presets/emo_amused.npz
new file mode 100644
index 0000000000000000000000000000000000000000..545712470a78ae6b3f91308779b612c9b8ef33b4
--- /dev/null
+++ b/presets/emo_amused.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:38be2ea16dc79beae68b6c885d99d4dad516acbd88ed5ed6991dd97301f2f30b
+size 15378
diff --git a/presets/emo_anger.npz b/presets/emo_anger.npz
new file mode 100644
index 0000000000000000000000000000000000000000..8cbf61bb2353db8a1337debe68e6c5113099fe46
--- /dev/null
+++ b/presets/emo_anger.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3261c3bdd5b7b4be9783d9293ee3d871be9d9d791f2b3a8bf62a1a0ee0ed93e6
+size 15434
diff --git a/presets/emo_neutral.npz b/presets/emo_neutral.npz
new file mode 100644
index 0000000000000000000000000000000000000000..ce1da3b25448c86b3ec2b2d2d0f19c56bca789c8
--- /dev/null
+++ b/presets/emo_neutral.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2188c4154692316ed7c0edee3aa3dd8678be36f355ee2b8c8a3a6412c3673ba9
+size 15578
diff --git a/presets/emo_sleepy.npz b/presets/emo_sleepy.npz
new file mode 100644
index 0000000000000000000000000000000000000000..b39ef24ea839f0a67663610473c2026751b96a72
--- /dev/null
+++ b/presets/emo_sleepy.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a53255890beaf4ed339e1967f0837fdb87c34c9f7e18bf77cd4b08eba176963
+size 15370
diff --git a/presets/emotion_sleepiness.npz b/presets/emotion_sleepiness.npz
new file mode 100644
index 0000000000000000000000000000000000000000..5b6bfc27f36658c0f62272ce30f357fec5911f97
--- /dev/null
+++ b/presets/emotion_sleepiness.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e0f866a278a10c7b6b494fb62589a9d8fef778ccf272df3b0d5510f45b243b5c
+size 33218
diff --git a/presets/en2zh_tts_1.npz b/presets/en2zh_tts_1.npz
new file mode 100644
index 0000000000000000000000000000000000000000..e73db03e27078932694dfdb6df5cc849c6bcc3d7
--- /dev/null
+++ b/presets/en2zh_tts_1.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d4de4ed055448ea54f7b40091afae565197f960d954279035ac537ea5a01bc4
+size 44354
diff --git a/presets/en2zh_tts_2.npz b/presets/en2zh_tts_2.npz
new file mode 100644
index 0000000000000000000000000000000000000000..d15ad2188a0f5fead60165d86c825dec7a914ac2
--- /dev/null
+++ b/presets/en2zh_tts_2.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dcc066ea104daa27d1552fe76574d09359d56fa892241581cc19e931a696eca9
+size 24178
diff --git a/presets/en2zh_tts_3.npz b/presets/en2zh_tts_3.npz
new file mode 100644
index 0000000000000000000000000000000000000000..f0aa9306b71c23cfadfd6eae0bb0b7a84084fade
--- /dev/null
+++ b/presets/en2zh_tts_3.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7468944e6d0ed7f2da033e8037be07dbafc76bd1ed7c0f5996d85ff45aacda11
+size 21410
diff --git a/presets/en2zh_tts_4.npz b/presets/en2zh_tts_4.npz
new file mode 100644
index 0000000000000000000000000000000000000000..b52465fadebb7f7f163a26f2e9d9633f703ad039
--- /dev/null
+++ b/presets/en2zh_tts_4.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0fd8d0914e74769114310e9504d68d6b7b0c6aacd46763478cbfd4f9631ad54a
+size 43826
diff --git a/presets/esta.npz b/presets/esta.npz
new file mode 100644
index 0000000000000000000000000000000000000000..4d75c2e4a934b61f824dc447c8592a5731025326
--- /dev/null
+++ b/presets/esta.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f944e135d901a00e74e7affe6757334e9a2679c10ad7ae4bcb5b33569d77eba
+size 40250
diff --git a/presets/fuxuan_2.npz b/presets/fuxuan_2.npz
new file mode 100644
index 0000000000000000000000000000000000000000..aaeb7f8bc5af0680a2d64e452e1d029f592aa44b
--- /dev/null
+++ b/presets/fuxuan_2.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17b90388d179ae309e1f577c28c3f10d9bed73c6ccbffdd829c00568eb3941e6
+size 50330
diff --git a/presets/librispeech_1.npz b/presets/librispeech_1.npz
new file mode 100644
index 0000000000000000000000000000000000000000..e2480cc12a6a526df5c552700f1507675cee62d8
--- /dev/null
+++ b/presets/librispeech_1.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:415b244e43b45291fd651d71f15bb7a31c244e2054988c436f6bbc04465c6099
+size 15650
diff --git a/presets/librispeech_2.npz b/presets/librispeech_2.npz
new file mode 100644
index 0000000000000000000000000000000000000000..0eed46188be3dea3293903a13daa718ab0c802c1
--- /dev/null
+++ b/presets/librispeech_2.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bd74e77370248b025321b9dbae25b1572f13f98da63255e384d382d2b0c78227
+size 15418
diff --git a/presets/librispeech_3.npz b/presets/librispeech_3.npz
new file mode 100644
index 0000000000000000000000000000000000000000..fbaa57d5d3c106ea9a77af43a6a2a3c0d3045773
--- /dev/null
+++ b/presets/librispeech_3.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1eceb3f4cc0f3a8856b5e3b5f1ca28c428d75305b1452da1ecf4013bc358ccaa
+size 15634
diff --git a/presets/librispeech_4.npz b/presets/librispeech_4.npz
new file mode 100644
index 0000000000000000000000000000000000000000..3516ee92a587b51c645856122a12503386f5dd28
--- /dev/null
+++ b/presets/librispeech_4.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3939dde39f5e65bc01f5eba9acb7b8329465aaca3c38edf1b240aa714e687960
+size 15594
diff --git a/presets/neutral.npz b/presets/neutral.npz
new file mode 100644
index 0000000000000000000000000000000000000000..6af010decf0d7459e76a0764a6495ecd9758c524
--- /dev/null
+++ b/presets/neutral.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8a63993526ffdc788a711b512d07a8b1c816151a1edb63913d0bfb48c2ea380
+size 21050
diff --git a/presets/paimon_1.npz b/presets/paimon_1.npz
new file mode 100644
index 0000000000000000000000000000000000000000..8e9cf23f35e99a3791ea54ac8f0700dd188d9db5
--- /dev/null
+++ b/presets/paimon_1.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:452d5e0cd3a060db521bd65a16af818a6177f357801402aa5581eceb2c24039a
+size 13762
diff --git a/presets/prompt_1.npz b/presets/prompt_1.npz
new file mode 100644
index 0000000000000000000000000000000000000000..a531db4affda6ffa8f7a0b997e1e6840fd87fe4b
--- /dev/null
+++ b/presets/prompt_1.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2bd0e41e72e657bdf9c6ceaea0294807faea2db623a0e33b39e1a8eebcf4d21c
+size 87338
diff --git a/presets/rosalia.npz b/presets/rosalia.npz
new file mode 100644
index 0000000000000000000000000000000000000000..800162152c8207d2c491b8c4018bf177ab6f8c8a
--- /dev/null
+++ b/presets/rosalia.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:af87ebe283bbb7b527c6c0ff0a02a315416485677fe23330040c2766fa9af919
+size 11414
diff --git a/presets/seel.npz b/presets/seel.npz
new file mode 100644
index 0000000000000000000000000000000000000000..095b1754f23a1030296b2a8f8f90b230e4b6dc1e
--- /dev/null
+++ b/presets/seel.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44ad2e900df3625f9753e949dc5a7d8479c4091e24cb18cbf46e34e29498d952
+size 13554
diff --git a/presets/seel_1.npz b/presets/seel_1.npz
new file mode 100644
index 0000000000000000000000000000000000000000..095b1754f23a1030296b2a8f8f90b230e4b6dc1e
--- /dev/null
+++ b/presets/seel_1.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:44ad2e900df3625f9753e949dc5a7d8479c4091e24cb18cbf46e34e29498d952
+size 13554
diff --git a/presets/sleepiness.npz b/presets/sleepiness.npz
new file mode 100644
index 0000000000000000000000000000000000000000..5b6bfc27f36658c0f62272ce30f357fec5911f97
--- /dev/null
+++ b/presets/sleepiness.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e0f866a278a10c7b6b494fb62589a9d8fef778ccf272df3b0d5510f45b243b5c
+size 33218
diff --git a/presets/vctk_1.npz b/presets/vctk_1.npz
new file mode 100644
index 0000000000000000000000000000000000000000..c23c917cdcc846bbd047edd409b182d236aa6d28
--- /dev/null
+++ b/presets/vctk_1.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c9df2ea8c2bc919c0ac50f8e05950bb4e831de69b33a7fb12d584da5b2512f2
+size 15530
diff --git a/presets/vctk_2.npz b/presets/vctk_2.npz
new file mode 100644
index 0000000000000000000000000000000000000000..a671e453cd54cf7345c5a1199b70280f877dae0d
--- /dev/null
+++ b/presets/vctk_2.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc84744435a304b3e700b8b1ab94c3b891db3056bd55a0f9dd99eff284016efa
+size 15458
diff --git a/presets/vctk_3.npz b/presets/vctk_3.npz
new file mode 100644
index 0000000000000000000000000000000000000000..1c045ead518d9f37699a0b59ebe57296e0542aef
--- /dev/null
+++ b/presets/vctk_3.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec0d528c6ae9c8f32b02ca6b57aa565b9fe63f401fd04f2632ed7e536699b9ac
+size 15450
diff --git a/presets/vctk_4.npz b/presets/vctk_4.npz
new file mode 100644
index 0000000000000000000000000000000000000000..1fbfbbdd4ef4e292e24f7276defadaefdcf0e98b
--- /dev/null
+++ b/presets/vctk_4.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ff2b71254ae00be6e42ad206c7616d168bd41582837e9eeb4d6cd669bd0b140
+size 15330
diff --git a/presets/yaesakura.npz b/presets/yaesakura.npz
new file mode 100644
index 0000000000000000000000000000000000000000..3f6b151870c881c61eb232dbb28c1403a67532df
--- /dev/null
+++ b/presets/yaesakura.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b388a18d286b4ba13d45bae373a716c0010dc40ae9c940d53b5a04cbc64e95ff
+size 12442
diff --git a/presets/yaesakura_1.npz b/presets/yaesakura_1.npz
new file mode 100644
index 0000000000000000000000000000000000000000..3f6b151870c881c61eb232dbb28c1403a67532df
--- /dev/null
+++ b/presets/yaesakura_1.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b388a18d286b4ba13d45bae373a716c0010dc40ae9c940d53b5a04cbc64e95ff
+size 12442
diff --git a/presets/zh2en_tts_1.npz b/presets/zh2en_tts_1.npz
new file mode 100644
index 0000000000000000000000000000000000000000..bbd2a9c750af5b6cac656b01ef36c2dd3ee766f7
--- /dev/null
+++ b/presets/zh2en_tts_1.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:07bff150ad145f9b06f0e7cbf9b0ee4d9e926600efa0d129bd831c8b2993c2b0
+size 23546
diff --git a/presets/zh2en_tts_2.npz b/presets/zh2en_tts_2.npz
new file mode 100644
index 0000000000000000000000000000000000000000..644f6cf976b91b284316a5e4513b72980d7557a8
--- /dev/null
+++ b/presets/zh2en_tts_2.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0257d0782578c7813c3f43b5e93c0e681f9ea42fe76775d5a4f4fea64609b03e
+size 20170
diff --git a/presets/zh2en_tts_3.npz b/presets/zh2en_tts_3.npz
new file mode 100644
index 0000000000000000000000000000000000000000..fe2ce9d14ae1af4ee307d1b0a109c141141957d9
--- /dev/null
+++ b/presets/zh2en_tts_3.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5da48e060d15f391767bffe1d528bfbc782a562413feed2e9bd2cafa82bf644a
+size 17906
diff --git a/presets/zh2en_tts_4.npz b/presets/zh2en_tts_4.npz
new file mode 100644
index 0000000000000000000000000000000000000000..693e32dc6f27b91270c8c466b1a6671fb0ed7054
--- /dev/null
+++ b/presets/zh2en_tts_4.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bda7a70ed9b03d8f1ff99d2444ea1df476a8deaf75633aa3b3f6cf3f45ae7e5e
+size 33682
diff --git a/prompts/en-1.wav b/prompts/en-1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..d4ea9ce79a35e86b9b45f536bdab6c548c12917e
Binary files /dev/null and b/prompts/en-1.wav differ
diff --git a/prompts/en-2.wav b/prompts/en-2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..d0d259eed7c223e74a9e58f8a05c385bf4553461
Binary files /dev/null and b/prompts/en-2.wav differ
diff --git a/prompts/ja-1.wav b/prompts/ja-1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..6b8cd477f860d85f8e87972a2f69b2a6be5e80f2
Binary files /dev/null and b/prompts/ja-1.wav differ
diff --git a/prompts/ja-2.ogg b/prompts/ja-2.ogg
new file mode 100644
index 0000000000000000000000000000000000000000..e7277d63d418a633d94c659710ddae0befa2587b
Binary files /dev/null and b/prompts/ja-2.ogg differ
diff --git a/prompts/ph.txt b/prompts/ph.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/prompts/zh-1.wav b/prompts/zh-1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..aa7b52fdc5e7b965cc00fd7db90f60299a54e6b3
Binary files /dev/null and b/prompts/zh-1.wav differ
diff --git a/prompts/zh-2.wav b/prompts/zh-2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..770a6bb330a84c38828b9b873f8cc82f8eb25f44
Binary files /dev/null and b/prompts/zh-2.wav differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b08c440376a2d744b7fd3eb50fd100a6074f0324
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,26 @@
+soundfile
+numpy
+torch==2.0.1
+torchvision==0.15.2
+torchaudio
+tokenizers
+encodec
+vocos
+langid
+unidecode
+pyopenjtalk
+pypinyin
+inflect
+cn2an
+jieba
+eng_to_ipa
+SudachiPy
+sudachidict_core
+nltk
+openai-whisper
+phonemizer
+matplotlib
+psutil
+transformers
+gradio
\ No newline at end of file
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f679b50adebcabe109af22d6f4596634b8da8fc1
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,15 @@
+import torch
+import torch.nn as nn
+# from icefall.utils import make_pad_mask
+
+from .symbol_table import SymbolTable
+
+# make_pad_mask = make_pad_mask
+SymbolTable = SymbolTable
+
+
+class Transpose(nn.Identity):
+ """(N, T, D) -> (N, D, T)"""
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return input.transpose(1, 2)
diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fe3ce17582a9b3e70aa7dcad5bf2dfd6b7fadead
Binary files /dev/null and b/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/utils/download.py b/utils/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ab5cd9394e8d9f2793e2e2eb8d3740c2cc2ee0b
--- /dev/null
+++ b/utils/download.py
@@ -0,0 +1,49 @@
+import sys
+import requests
+
+
+def download_file_from_google_drive(id, destination):
+ URL = "https://docs.google.com/uc?export=download&confirm=1"
+
+ session = requests.Session()
+
+ response = session.get(URL, params={"id": id}, stream=True)
+ token = get_confirm_token(response)
+
+ if token:
+ params = {"id": id, "confirm": token}
+ response = session.get(URL, params=params, stream=True)
+
+ save_response_content(response, destination)
+
+
+def get_confirm_token(response):
+ for key, value in response.cookies.items():
+ if key.startswith("download_warning"):
+ return value
+
+ return None
+
+
+def save_response_content(response, destination):
+ CHUNK_SIZE = 32768
+
+ with open(destination, "wb") as f:
+ for chunk in response.iter_content(CHUNK_SIZE):
+            if chunk:  # filter out keep-alive chunks
+ f.write(chunk)
+
+
+def main():
+ if len(sys.argv) >= 3:
+ file_id = sys.argv[1]
+ destination = sys.argv[2]
+ else:
+ file_id = "TAKE_ID_FROM_SHAREABLE_LINK"
+ destination = "DESTINATION_FILE_ON_YOUR_DISK"
+    print(f"download {file_id} to {destination}")
+ download_file_from_google_drive(file_id, destination)
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
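`download_file_from_google_drive` streams a publicly shared Drive file to disk, re-requesting with the confirmation token that Google returns for large files. A hedged usage sketch; the file id below is a placeholder from a hypothetical share link, not a real checkpoint id:

    from utils.download import download_file_from_google_drive

    # the id comes from a link of the form https://drive.google.com/file/d/<FILE_ID>/view?usp=sharing
    download_file_from_google_drive("FILE_ID_PLACEHOLDER", "./checkpoints/model.pt")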
diff --git a/utils/g2p/__init__.py b/utils/g2p/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6da9152cd58393f39937085139ee36d55ca7367
--- /dev/null
+++ b/utils/g2p/__init__.py
@@ -0,0 +1,72 @@
+""" from https://github.com/keithito/tacotron """
+import utils.g2p.cleaners as cleaners
+from utils.g2p.symbols import symbols
+from tokenizers import Tokenizer
+
+# Mappings from symbol to numeric ID and vice versa:
+_symbol_to_id = {s: i for i, s in enumerate(symbols)}
+_id_to_symbol = {i: s for i, s in enumerate(symbols)}
+
+
+class PhonemeBpeTokenizer:
+ def __init__(self, tokenizer_path = "./utils/g2p/bpe_1024.json"):
+ self.tokenizer = Tokenizer.from_file(tokenizer_path)
+
+ def tokenize(self, text):
+ # 1. convert text to phoneme
+ phonemes, langs = _clean_text(text, ['cje_cleaners'])
+ # 2. replace blank space " " with "_"
+ phonemes = phonemes.replace(" ", "_")
+ # 3. tokenize phonemes
+ phoneme_tokens = self.tokenizer.encode(phonemes).ids
+ assert(len(phoneme_tokens) == len(langs))
+ if not len(phoneme_tokens):
+ raise ValueError("Empty text is given")
+ return phoneme_tokens, langs
+
+def text_to_sequence(text, cleaner_names):
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+ Args:
+ text: string to convert to a sequence
+ cleaner_names: names of the cleaner functions to run the text through
+ Returns:
+ List of integers corresponding to the symbols in the text
+ '''
+ sequence = []
+ symbol_to_id = {s: i for i, s in enumerate(symbols)}
+ clean_text = _clean_text(text, cleaner_names)
+ for symbol in clean_text:
+ if symbol not in symbol_to_id.keys():
+ continue
+ symbol_id = symbol_to_id[symbol]
+ sequence += [symbol_id]
+ return sequence
+
+
+def cleaned_text_to_sequence(cleaned_text):
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+ Args:
+ text: string to convert to a sequence
+ Returns:
+ List of integers corresponding to the symbols in the text
+ '''
+ sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
+ return sequence
+
+
+def sequence_to_text(sequence):
+ '''Converts a sequence of IDs back to a string'''
+ result = ''
+ for symbol_id in sequence:
+ s = _id_to_symbol[symbol_id]
+ result += s
+ return result
+
+
+def _clean_text(text, cleaner_names):
+ for name in cleaner_names:
+ cleaner = getattr(cleaners, name)
+ if not cleaner:
+ raise Exception('Unknown cleaner: %s' % name)
+ text, langs = cleaner(text)
+ return text, langs
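`PhonemeBpeTokenizer.tokenize` runs the `cje_cleaners` pipeline to get a phoneme string plus one language code per phoneme character, replaces spaces with "_", and then encodes the result with the loaded tokenizer. A usage sketch, assuming the repository root is the working directory and using `bpe_69.json` (the character-level vocabulary that `utils/generation.py` also passes in, which keeps token and language counts aligned for the assertion above):

    from utils.g2p import PhonemeBpeTokenizer

    tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
    tokens, langs = tokenizer.tokenize("_[EN]Hello world.[EN]")
    # tokens: list of symbol ids; langs: 'en'/'zh'/'ja' per phoneme character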
diff --git a/utils/g2p/__pycache__/__init__.cpython-38.pyc b/utils/g2p/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9f5652536aec5cb4909f4b8c4d6f82ef784263f4
Binary files /dev/null and b/utils/g2p/__pycache__/__init__.cpython-38.pyc differ
diff --git a/utils/g2p/__pycache__/cleaners.cpython-38.pyc b/utils/g2p/__pycache__/cleaners.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a0a76161b1a4d572b0b265e5a8b10a0dca9e4e6
Binary files /dev/null and b/utils/g2p/__pycache__/cleaners.cpython-38.pyc differ
diff --git a/utils/g2p/__pycache__/english.cpython-38.pyc b/utils/g2p/__pycache__/english.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e4acb6d4ef61dffaaac2ee7068dea87565667c3
Binary files /dev/null and b/utils/g2p/__pycache__/english.cpython-38.pyc differ
diff --git a/utils/g2p/__pycache__/japanese.cpython-38.pyc b/utils/g2p/__pycache__/japanese.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c9d9e630232f20ea50434d8d1cab07658eb5b70
Binary files /dev/null and b/utils/g2p/__pycache__/japanese.cpython-38.pyc differ
diff --git a/utils/g2p/__pycache__/mandarin.cpython-38.pyc b/utils/g2p/__pycache__/mandarin.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b72016aa9b874286d2b4e824343406d91db2c773
Binary files /dev/null and b/utils/g2p/__pycache__/mandarin.cpython-38.pyc differ
diff --git a/utils/g2p/__pycache__/symbols.cpython-38.pyc b/utils/g2p/__pycache__/symbols.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..922be9f846d866df964da05509553e79476cc119
Binary files /dev/null and b/utils/g2p/__pycache__/symbols.cpython-38.pyc differ
diff --git a/utils/g2p/bpe_1024.json b/utils/g2p/bpe_1024.json
new file mode 100644
index 0000000000000000000000000000000000000000..19331439acfb9d16944399e986b0aee38c95758c
--- /dev/null
+++ b/utils/g2p/bpe_1024.json
@@ -0,0 +1,2049 @@
+{
+ "version": "1.0",
+ "truncation": null,
+ "padding": null,
+ "added_tokens": [
+ {
+ "id": 0,
+ "content": "[UNK]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ },
+ {
+ "id": 1,
+ "content": "[CLS]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ },
+ {
+ "id": 2,
+ "content": "[SEP]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ },
+ {
+ "id": 3,
+ "content": "[PAD]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ },
+ {
+ "id": 4,
+ "content": "[MASK]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ }
+ ],
+ "normalizer": null,
+ "pre_tokenizer": {
+ "type": "Whitespace"
+ },
+ "post_processor": null,
+ "decoder": null,
+ "model": {
+ "type": "BPE",
+ "dropout": null,
+ "unk_token": "[UNK]",
+ "continuing_subword_prefix": null,
+ "end_of_word_suffix": null,
+ "fuse_unk": false,
+ "byte_fallback": false,
+ "vocab": {
+ "[UNK]": 0,
+ "[CLS]": 1,
+ "[SEP]": 2,
+ "[PAD]": 3,
+ "[MASK]": 4,
+ "!": 5,
+ "#": 6,
+ "*": 7,
+ ",": 8,
+ "-": 9,
+ ".": 10,
+ "=": 11,
+ "?": 12,
+ "N": 13,
+ "Q": 14,
+ "^": 15,
+ "_": 16,
+ "`": 17,
+ "a": 18,
+ "b": 19,
+ "d": 20,
+ "e": 21,
+ "f": 22,
+ "g": 23,
+ "h": 24,
+ "i": 25,
+ "j": 26,
+ "k": 27,
+ "l": 28,
+ "m": 29,
+ "n": 30,
+ "o": 31,
+ "p": 32,
+ "s": 33,
+ "t": 34,
+ "u": 35,
+ "v": 36,
+ "w": 37,
+ "x": 38,
+ "y": 39,
+ "z": 40,
+ "~": 41,
+ "æ": 42,
+ "ç": 43,
+ "ð": 44,
+ "ŋ": 45,
+ "ɑ": 46,
+ "ɔ": 47,
+ "ə": 48,
+ "ɛ": 49,
+ "ɥ": 50,
+ "ɪ": 51,
+ "ɫ": 52,
+ "ɯ": 53,
+ "ɸ": 54,
+ "ɹ": 55,
+ "ɾ": 56,
+ "ʃ": 57,
+ "ʊ": 58,
+ "ʑ": 59,
+ "ʒ": 60,
+ "ʰ": 61,
+ "ˈ": 62,
+ "ˌ": 63,
+ "θ": 64,
+ "…": 65,
+ "⁼": 66,
+ "↑": 67,
+ "→": 68,
+ "↓": 69,
+ "_t": 70,
+ "↓↑": 71,
+ "_ˈ": 72,
+ "ən": 73,
+ "_s": 74,
+ "aɪ": 75,
+ "əɹ": 76,
+ "eɪ": 77,
+ "oʊ": 78,
+ "_k": 79,
+ "ʃi": 80,
+ "_w": 81,
+ "_ð": 82,
+ "ts": 83,
+ "tʃ": 84,
+ "_ts": 85,
+ "_h": 86,
+ "_ə": 87,
+ "_m": 88,
+ "an": 89,
+ "_n": 90,
+ "_ðə": 91,
+ "ɛn": 92,
+ "ɑʊ": 93,
+ "ɑŋ": 94,
+ "`⁼": 95,
+ "_p": 96,
+ "_i": 97,
+ "_ɪ": 98,
+ "_tʃ": 99,
+ "_l": 100,
+ "jɛn": 101,
+ "_d": 102,
+ "_f": 103,
+ "_j": 104,
+ "wo": 105,
+ "_b": 106,
+ "ta": 107,
+ "`↓": 108,
+ "te": 109,
+ "ənd": 110,
+ "_ʃi": 111,
+ "wa": 112,
+ "ka": 113,
+ "ɪŋ": 114,
+ "in": 115,
+ "st": 116,
+ "li": 117,
+ "ʊŋ": 118,
+ "_tɪ": 119,
+ "to": 120,
+ "weɪ": 121,
+ "_ənd": 122,
+ "ʰi": 123,
+ "_əv": 124,
+ "əŋ": 125,
+ "no": 126,
+ "_x": 127,
+ "ɾɯ": 128,
+ "na": 129,
+ "_a": 130,
+ "_ɹ": 131,
+ "ɪn": 132,
+ "ga": 133,
+ "de": 134,
+ "joʊ": 135,
+ "æn": 136,
+ "kɯ": 137,
+ "ɾe": 138,
+ "ma": 139,
+ "_ðə_ˈ": 140,
+ "ɾa": 141,
+ "ɛɹ": 142,
+ "mo": 143,
+ "ɔɹ": 144,
+ "əɫ": 145,
+ "_g": 146,
+ "da": 147,
+ "*↑": 148,
+ "ɪˈ": 149,
+ "_o": 150,
+ "_ʃ": 151,
+ "iŋ": 152,
+ "ja": 153,
+ "əm": 154,
+ "_ˌ": 155,
+ "aʊ": 156,
+ "_əˈ": 157,
+ "`↑": 158,
+ "ət": 159,
+ "_aɪ": 160,
+ "oo": 161,
+ "sɯ": 162,
+ "↓.": 163,
+ "_ɪn": 164,
+ "_hi": 165,
+ "_wɪ": 166,
+ "ɪz": 167,
+ "_na": 168,
+ "wan": 169,
+ "_ko": 170,
+ "_wo": 171,
+ "ɪd": 172,
+ "ɾi": 173,
+ "_ju": 174,
+ "mə": 175,
+ "_lə": 176,
+ "_hæ": 177,
+ "_ðət": 178,
+ "ɑɹ": 179,
+ "tʰ": 180,
+ "ki": 181,
+ "……": 182,
+ "ɑz": 183,
+ "_ɔ": 184,
+ "_mi": 185,
+ "_wɑz": 186,
+ "_ˈs": 187,
+ "↓,": 188,
+ "_tʰ": 189,
+ "əˈ": 190,
+ "dʑ": 191,
+ "ɪt": 192,
+ "_kʰ": 193,
+ "iɛ": 194,
+ "_ma": 195,
+ "ɪs": 196,
+ "tsɯ": 197,
+ "_ni": 198,
+ "_ɪt": 199,
+ "ke": 200,
+ "iɑʊ": 201,
+ "_ka": 202,
+ "_əɹ": 203,
+ "nd": 204,
+ "_ˈp": 205,
+ "ko": 206,
+ "jo": 207,
+ "ɹi": 208,
+ "mən": 209,
+ "ʊd": 210,
+ "_ˈm": 211,
+ "_fəɹ": 212,
+ "tʃʰi": 213,
+ "sa": 214,
+ "ʰɥ": 215,
+ "kʰ": 216,
+ "ˈs": 217,
+ "ɑt": 218,
+ "ɛd": 219,
+ "se": 220,
+ "tʃi": 221,
+ "ɛɫ": 222,
+ "_ˈk": 223,
+ "_joʊ": 224,
+ "təɹ": 225,
+ "ɛz": 226,
+ "--": 227,
+ "vəɹ": 228,
+ "`→": 229,
+ "ʃən": 230,
+ "_ɪz": 231,
+ "_meɪ": 232,
+ "_æ": 233,
+ "dʒ": 234,
+ "_ki": 235,
+ "_hɪz": 236,
+ "_bi": 237,
+ "uɑŋ": 238,
+ "_ˈf": 239,
+ "↓↑.": 240,
+ "_wɪθ": 241,
+ "ju": 242,
+ "iɑŋ": 243,
+ "→.": 244,
+ "_so": 245,
+ "_həɹ": 246,
+ "↑.": 247,
+ "ni": 248,
+ "_mo": 249,
+ "_maɪ": 250,
+ "laɪ": 251,
+ "ɥɛ": 252,
+ "_ta": 253,
+ "ənt": 254,
+ "_tʃʰi": 255,
+ "_sɯ": 256,
+ "_θ": 257,
+ "_ɛz": 258,
+ "wən": 259,
+ "me": 260,
+ "mi": 261,
+ "_hæd": 262,
+ "_ha": 263,
+ "əs": 264,
+ "_ˈl": 265,
+ "_st": 266,
+ "ðəɹ": 267,
+ "oʊn": 268,
+ "_wa": 269,
+ "ʰəŋ": 270,
+ "_nɑt": 271,
+ "*.": 272,
+ "kt": 273,
+ "_ˈh": 274,
+ "do": 275,
+ "ɥæn": 276,
+ "ne": 277,
+ "_to": 278,
+ "_wən": 279,
+ "_no": 280,
+ "_laɪ": 281,
+ "_wəɹ": 282,
+ "↑,": 283,
+ "→,": 284,
+ "ɛs": 285,
+ "↓↑,": 286,
+ "_ɔn": 287,
+ "ʰu": 288,
+ "so": 289,
+ "_ˈb": 290,
+ "ɫd": 291,
+ "ɪk": 292,
+ "ɪst": 293,
+ "_fɹ": 294,
+ "_ðɛɹ": 295,
+ "_weɪ": 296,
+ "kaɾa": 297,
+ "_ˈd": 298,
+ "_hæv": 299,
+ "tsʰ": 300,
+ "waɪ": 301,
+ "ɾo": 302,
+ "ɛm": 303,
+ "_æt": 304,
+ "ʊɹ": 305,
+ "_ˈw": 306,
+ "ba": 307,
+ "_noʊ": 308,
+ "ʰjɛn": 309,
+ "ɹeɪ": 310,
+ "_jo": 311,
+ "ɸɯ": 312,
+ "_sa": 313,
+ "_ɹɪˈ": 314,
+ "_ˈn": 315,
+ "ai": 316,
+ "_bət": 317,
+ "ɪɹ": 318,
+ "tʃʰɥ": 319,
+ "_dʑ": 320,
+ "əˌ": 321,
+ "_ðɪs": 322,
+ "..": 323,
+ "xwa": 324,
+ "_ɪm": 325,
+ "_dɪˈ": 326,
+ "_kən": 327,
+ "dʑi": 328,
+ "*,": 329,
+ "ɑn": 330,
+ "_ʃiɑŋ": 331,
+ "_kɯ": 332,
+ "ʃin": 333,
+ "_soʊ": 334,
+ "bi": 335,
+ "tʰjɛn": 336,
+ "te_i": 337,
+ "_tsʰ": 338,
+ "_ɯ": 339,
+ "aɪt": 340,
+ "ʰiŋ": 341,
+ "ðə": 342,
+ "_ɔɫ": 343,
+ "_ˈɹ": 344,
+ "nai": 345,
+ "əɹd": 346,
+ "_ˈt": 347,
+ "_ən": 348,
+ "_tʃʰɥ": 349,
+ "_iɛ": 350,
+ "leɪ": 351,
+ "ɛɹi": 352,
+ "ˈt": 353,
+ "ha": 354,
+ "ʃiŋ": 355,
+ "ɛvəɹ": 356,
+ "zɯ": 357,
+ "_wi": 358,
+ "_ja": 359,
+ "ɛk": 360,
+ "ʰɑŋ": 361,
+ "_tsɯ": 362,
+ "_əv_ðə": 363,
+ "taʃi": 364,
+ "_sɛd": 365,
+ "_xə": 366,
+ "_li": 367,
+ "_si": 368,
+ "desɯ": 369,
+ "_ˌɪn": 370,
+ "ʃjɛn": 371,
+ "_baɪ": 372,
+ "on": 373,
+ "_xɑʊ": 374,
+ "_ðeɪ": 375,
+ "_xaɪ": 376,
+ "`↓↑": 377,
+ "xweɪ": 378,
+ "hi": 379,
+ "_se": 380,
+ "ə_s": 381,
+ "_fɹəm": 382,
+ "ʊt": 383,
+ "di": 384,
+ "aʊt": 385,
+ "əb": 386,
+ "sɹ": 387,
+ "əz": 388,
+ "_xweɪ": 389,
+ "_kʰə": 390,
+ "ɹu": 391,
+ "_u": 392,
+ "_de": 393,
+ "aɪd": 394,
+ "ɪv": 395,
+ "bɯ": 396,
+ "_ho": 397,
+ "əɹz": 398,
+ "joo": 399,
+ "_bɪˈ": 400,
+ "_tʰa": 401,
+ "ɛt": 402,
+ "en": 403,
+ "ɛni": 404,
+ "əst": 405,
+ "æk": 406,
+ "ə_ts": 407,
+ "_ˈɪn": 408,
+ "ti": 409,
+ "ɥn": 410,
+ "_dʒ": 411,
+ "xɑʊ": 412,
+ "_ˈv": 413,
+ "ʃiɑŋ": 414,
+ "pʰ": 415,
+ "_wɪtʃ": 416,
+ "eɪm": 417,
+ "oʊz": 418,
+ "əðəɹ": 419,
+ "fɑŋ": 420,
+ "_ˈg": 421,
+ "_do": 422,
+ "_ʃiɑʊ": 423,
+ "_ˈæ": 424,
+ "_jʊɹ": 425,
+ "_ðɛm": 426,
+ "ɪm": 427,
+ "ɛst": 428,
+ "ænd": 429,
+ "_du": 430,
+ "ɯɯ": 431,
+ "kan": 432,
+ "_da": 433,
+ "ino": 434,
+ "_e": 435,
+ "_wʊd": 436,
+ "ɛnd": 437,
+ "meɪ": 438,
+ "θɪŋ": 439,
+ "_ʃjɛn": 440,
+ "iz": 441,
+ "aɪm": 442,
+ "_hu": 443,
+ "_əˈb": 444,
+ "əns": 445,
+ "_wɪɫ": 446,
+ "tʰi": 447,
+ "go": 448,
+ "ɛnt": 449,
+ "fu": 450,
+ "æp": 451,
+ "xoʊ": 452,
+ "eɪk": 453,
+ "ʊk": 454,
+ "əɹˈ": 455,
+ "_θɪŋ": 456,
+ "əl": 457,
+ "pɹ": 458,
+ "ətʃ": 459,
+ "nt": 460,
+ "_ɸɯ": 461,
+ "lu": 462,
+ "_ˈɔ": 463,
+ "_iɑʊ": 464,
+ "lə": 465,
+ "tu": 466,
+ "_dʑi": 467,
+ "eɪt": 468,
+ "_ʃin": 469,
+ "nna": 470,
+ "_ˈpɹ": 471,
+ "fən": 472,
+ "_əp": 473,
+ "njɛn": 474,
+ "_aʊt": 475,
+ "fɔɹ": 476,
+ "_tu": 477,
+ "eɪʃən": 478,
+ "ɪɫ": 479,
+ "_wət": 480,
+ "_ɪf": 481,
+ "_ɥ": 482,
+ "_fa": 483,
+ "ˈw": 484,
+ "tʃʰjɛn": 485,
+ "_wɪn": 486,
+ "oʊɫd": 487,
+ "_əˈp": 488,
+ "aʊnd": 489,
+ "san": 490,
+ "he": 491,
+ "_bɪn": 492,
+ "fa": 493,
+ "ɪf": 494,
+ "ɔŋ": 495,
+ "ge": 496,
+ "_ɪn_ðə": 497,
+ "miŋ": 498,
+ "_pɹ": 499,
+ "ina": 500,
+ "ano": 501,
+ "əbəɫ": 502,
+ "kˈs": 503,
+ "_ˈɛni": 504,
+ "nəŋ": 505,
+ "əd": 506,
+ "_əv_ðə_ˈ": 507,
+ "_waɪ": 508,
+ "_taɪm": 509,
+ "ˈsɛɫ": 510,
+ "ʃiɛ": 511,
+ "_kəm": 512,
+ "æst": 513,
+ "_goʊ": 514,
+ "mɯ": 515,
+ "ˈp": 516,
+ "_ˈst": 517,
+ "ə_t": 518,
+ "pt": 519,
+ "_pʰ": 520,
+ "ʰɹ": 521,
+ "ʃja": 522,
+ "iwa": 523,
+ "ɪl": 524,
+ "bət": 525,
+ "_fɑŋ": 526,
+ "ho": 527,
+ "iv": 528,
+ "loʊ": 529,
+ "be": 530,
+ "_laɪk": 531,
+ "ɪʃ": 532,
+ "_fu": 533,
+ "ze": 534,
+ "ə_tʃ": 535,
+ "ɑɹt": 536,
+ "ɔɹd": 537,
+ "tʃʰiŋ": 538,
+ "mp": 539,
+ "_ðə_s": 540,
+ "_əˈbaʊt": 541,
+ "_ˈoʊ": 542,
+ "kʰə": 543,
+ "d_tɪ": 544,
+ "ŋga": 545,
+ "əli": 546,
+ "_kʰan": 547,
+ "çi": 548,
+ "_ˈju": 549,
+ "_kʊd": 550,
+ "ɔɫ": 551,
+ "ɔt": 552,
+ "_ɪts": 553,
+ "_san": 554,
+ "tʃa": 555,
+ "i_na": 556,
+ "xə": 557,
+ "ɛkt": 558,
+ "_mɔɹ": 559,
+ "te_kɯ": 560,
+ "ɪdʒ": 561,
+ "jʊŋ": 562,
+ "_wan": 563,
+ "æt": 564,
+ "kat": 565,
+ "ˈsɛɫf": 566,
+ "_ke": 567,
+ "aɪnd": 568,
+ "it": 569,
+ "_ɑɹ": 570,
+ "sp": 571,
+ "oʊnt": 572,
+ "_tʃi": 573,
+ "tsʰɹ": 574,
+ "_xən": 575,
+ "_əˈg": 576,
+ "ə_k": 577,
+ "to_i": 578,
+ "_tʰi": 579,
+ "_iŋ": 580,
+ "aʊn": 581,
+ "gɯ": 582,
+ "_ɪkˈs": 583,
+ "ɛv": 584,
+ "gi": 585,
+ "ks": 586,
+ "_səm": 587,
+ "ana": 588,
+ "ɪtəɫ": 589,
+ "nan": 590,
+ "_ˈɪntu": 591,
+ "_hiɹ": 592,
+ "_te": 593,
+ "_naʊ": 594,
+ "ʃiɑʊ": 595,
+ "ʃo": 596,
+ "ɹe": 597,
+ "xaɪ": 598,
+ "_tʃʰiŋ": 599,
+ "_sɹ": 600,
+ "_haʊ": 601,
+ "?.": 602,
+ "_feɪ": 603,
+ "liŋ": 604,
+ "_ʃja": 605,
+ "_ˈdʒ": 606,
+ "_seɪ": 607,
+ "ˈn": 608,
+ "soʊ": 609,
+ "tʰʊŋ": 610,
+ "_ljoʊ": 611,
+ "maɪ": 612,
+ "_bɹ": 613,
+ "ɹeɪt": 614,
+ "_nəŋ": 615,
+ "ʰə": 616,
+ "æns": 617,
+ "_ˈɔl": 618,
+ "tatʃi": 619,
+ "nto": 620,
+ "_ˌɪnˈ": 621,
+ "le": 622,
+ "nde": 623,
+ "_ˈvɛɹi": 624,
+ "mənt": 625,
+ "ɾima": 626,
+ "_ðɛn": 627,
+ "_həz": 628,
+ "_ɹi": 629,
+ "ftəɹ": 630,
+ "_sp": 631,
+ "ɾewa": 632,
+ "ga_a": 633,
+ "z_əv": 634,
+ "_miŋ": 635,
+ "_tɪ_ðə": 636,
+ "ɹaɪ": 637,
+ "ɛl": 638,
+ "ɹæ": 639,
+ "_hoʊ": 640,
+ "xu": 641,
+ "oʊnli": 642,
+ "ŋk": 643,
+ "i_i": 644,
+ "_dɪd": 645,
+ "_dʒɪst": 646,
+ "ing": 647,
+ "kai": 648,
+ "_mæn": 649,
+ "_in": 650,
+ "zo": 651,
+ "əf": 652,
+ "dake": 653,
+ "_ˈsəm": 654,
+ "ɾɯ_no": 655,
+ "_go": 656,
+ "tʃəɹ": 657,
+ "ite": 658,
+ "`↓.": 659,
+ "_kʰaɪ": 660,
+ "sk": 661,
+ "ɔɹs": 662,
+ "_tʰiŋ": 663,
+ "_nə": 664,
+ "pəɫ": 665,
+ "_tɪ_bi": 666,
+ "ˈfɔɹ": 667,
+ "mu": 668,
+ "su": 669,
+ "aa": 670,
+ "ɪstəɹ": 671,
+ "ʰan": 672,
+ "pəɹ": 673,
+ "ə_p": 674,
+ "liɑŋ": 675,
+ "_v": 676,
+ "oʊst": 677,
+ "_əˈgɛn": 678,
+ "ənz": 679,
+ "No": 680,
+ "ɔɹt": 681,
+ "_səˈ": 682,
+ "_mɯ": 683,
+ "tʃʰ": 684,
+ "_ˈlɪtəɫ": 685,
+ "_xwo": 686,
+ "_ˌbi": 687,
+ "_ˈoʊvəɹ": 688,
+ "_çi": 689,
+ "_deɪ": 690,
+ "aɪn": 691,
+ "_ʃiŋ": 692,
+ "i_ʃi": 693,
+ "_tsʰaɪ": 694,
+ "ʃoo": 695,
+ "ɾoo": 696,
+ "bəɹ": 697,
+ "ʰa": 698,
+ "ˈɛs": 699,
+ "_ɪn_ðə_ˈ": 700,
+ "Nwa": 701,
+ "_ðən": 702,
+ "saɪ": 703,
+ "_ˈjuˈɛs": 704,
+ "nda": 705,
+ "_pleɪ": 706,
+ "ɪŋ_tɪ": 707,
+ "ɪti": 708,
+ "_me": 709,
+ "_ʃʊd": 710,
+ "_nu": 711,
+ "_ðə_k": 712,
+ "za": 713,
+ "_ˈɛvəɹ": 714,
+ "əɹn": 715,
+ "æd": 716,
+ "ˈm": 717,
+ "_doʊnt": 718,
+ "_məst": 719,
+ "jɯɯ": 720,
+ "ɑɹd": 721,
+ "_jɛn": 722,
+ "ʃɥ": 723,
+ "_ˈoʊnli": 724,
+ "_ʃo": 725,
+ "_liŋ": 726,
+ "ss": 727,
+ "ɑl": 728,
+ "dea": 729,
+ "ɾeta": 730,
+ "mjɛn": 731,
+ "_gʊd": 732,
+ "_wɔ": 733,
+ "imo": 734,
+ "no_ko": 735,
+ "_ɥæn": 736,
+ "ndʒ": 737,
+ "ɪʃən": 738,
+ "o_ʃi": 739,
+ "_θɪŋk": 740,
+ "_nan": 741,
+ "to_o": 742,
+ "_tʰʊŋ": 743,
+ "ljoʊ": 744,
+ "tai": 745,
+ "mə_s": 746,
+ "_jɯ": 747,
+ "_uɑŋ": 748,
+ "_ˌbiˈfɔɹ": 749,
+ "æs": 750,
+ "_tʃʰjɛn": 751,
+ "ik": 752,
+ "_bæk": 753,
+ "_ˈiv": 754,
+ "eɪn": 755,
+ "un": 756,
+ "la": 757,
+ "ˈk": 758,
+ "_daʊn": 759,
+ "anai": 760,
+ "_lɛ": 761,
+ "əɹt": 762,
+ "ðɛɹ": 763,
+ "_ˈæftəɹ": 764,
+ "dat": 765,
+ "fan": 766,
+ "bəɫ": 767,
+ "temo": 768,
+ "tʰa": 769,
+ "ɾɯ_ko": 770,
+ "ˈv": 771,
+ "feɪ": 772,
+ "_mətʃ": 773,
+ "xwo": 774,
+ "ɹoʊ": 775,
+ "_ba": 776,
+ "_ˈnɛvəɹ": 777,
+ "_meɪd": 778,
+ "_jʊŋ": 779,
+ "_əˈpɑn": 780,
+ "!?": 781,
+ "_ˈʃ": 782,
+ "_ðə_ˈk": 783,
+ "ft": 784,
+ "_bo": 785,
+ "_ɪn_ə": 786,
+ "tʃʰɥæn": 787,
+ "ˈz": 788,
+ "`↓,": 789,
+ "_bɪˈk": 790,
+ "ɪg": 791,
+ "kin": 792,
+ "_kl": 793,
+ "ɾɯ_n": 794,
+ "_lɑʊ": 795,
+ "----": 796,
+ "ika": 797,
+ "_ɹaɪt": 798,
+ "zd": 799,
+ "z_ənd": 800,
+ "_kjo": 801,
+ "xwan": 802,
+ "too": 803,
+ "_gɪt": 804,
+ "_liɑŋ": 805,
+ "ta_n": 806,
+ "_keɪm": 807,
+ "_ˈəðəɹ": 808,
+ "_wɛɫ": 809,
+ "teki": 810,
+ "see": 811,
+ "jɯ": 812,
+ "i_o": 813,
+ "to_ʃi": 814,
+ "fəɫ": 815,
+ "bo": 816,
+ "ˌt": 817,
+ "ɪp": 818,
+ "ane": 819,
+ "_tʰjɛn": 820,
+ "_tʃo": 821,
+ "ɾjo": 822,
+ "ɪns": 823,
+ "_he": 824,
+ "ŋka": 825,
+ "ʃɥɛ": 826,
+ "dʑa": 827,
+ "vd": 828,
+ "ʰwan": 829,
+ "_gɹeɪt": 830,
+ "_əv_ə": 831,
+ "əndəɹ": 832,
+ "kedo": 833,
+ "_ðə_b": 834,
+ "ək": 835,
+ "_teɪk": 836,
+ "kʰan": 837,
+ "_ˈɔlˌ": 838,
+ "swo": 839,
+ "_ɪt_wɑz": 840,
+ "_ʃɥ": 841,
+ "_sim": 842,
+ "_ˈfɑ": 843,
+ "min": 844,
+ "i_a": 845,
+ "soo": 846,
+ "ɛns": 847,
+ "_sətʃ": 848,
+ "tʰaɪ": 849,
+ "_ga": 850,
+ "i_ka": 851,
+ "koo": 852,
+ "_fəɹst": 853,
+ "_ˈtʃ": 854,
+ "nno": 855,
+ "ə_ɹ": 856,
+ "taɾa": 857,
+ "tʃʰjoʊ": 858,
+ "_æm": 859,
+ "_mu": 860,
+ "_meɪk": 861,
+ "↓…": 862,
+ "ɪˈθ": 863,
+ "ɑb": 864,
+ "ɹa": 865,
+ "_wɛɹ": 866,
+ "_ðə_ˈs": 867,
+ "_əˈl": 868,
+ "_oʊɫd": 869,
+ "æl": 870,
+ "_ˈpi": 871,
+ "_lɔŋ": 872,
+ "dʑo": 873,
+ "_tʰaɪ": 874,
+ "ɔɹn": 875,
+ "əɫz": 876,
+ "_təˈ": 877,
+ "_əˈweɪ": 878,
+ "pa": 879,
+ "_ðiz": 880,
+ "_ˈsp": 881,
+ "nn": 882,
+ "mae": 883,
+ "towa": 884,
+ "ta_no": 885,
+ "_an": 886,
+ "kʰaɪ": 887,
+ "ɾaɾe": 888,
+ "eɪs": 889,
+ "ɑd": 890,
+ "_wɪˈθ": 891,
+ "_ˈivɪn": 892,
+ "_lu": 893,
+ "ɔɪ": 894,
+ "lɪŋ": 895,
+ "əti": 896,
+ "_ðə_f": 897,
+ "oʃi": 898,
+ "_la": 899,
+ "si": 900,
+ "tɪd": 901,
+ "haʊ": 902,
+ "pʰin": 903,
+ "ˈst": 904,
+ "_ˈpəɹ": 905,
+ "eɹ": 906,
+ "*!": 907,
+ "_ˈmɪstəɹ": 908,
+ "ʃa": 909,
+ "_ˌɪm": 910,
+ "ˌθɪŋ": 911,
+ "_neɪ": 912,
+ "_nɥ": 913,
+ "ɑk": 914,
+ "_ɹu": 915,
+ "_ʃɯ": 916,
+ "_ðə_ˈm": 917,
+ "demo": 918,
+ "_dɹ": 919,
+ "dʑoo": 920,
+ "_stɪɫ": 921,
+ "_pʰiŋ": 922,
+ "ə_i": 923,
+ "_ɪkˈsp": 924,
+ "_wɛnt": 925,
+ "ɪɹi": 926,
+ "əˈm": 927,
+ "o_ka": 928,
+ "_əˈk": 929,
+ "ɔk": 930,
+ "_ɥɛ": 931,
+ "_lʊk": 932,
+ "ˈd": 933,
+ "kaʃi": 934,
+ "_wɪθ_ə": 935,
+ "ljɛn": 936,
+ "ɔn": 937,
+ "_ljɛn": 938,
+ "_hɛɫ": 939,
+ "uɹ": 940,
+ "_tʰoʊ": 941,
+ "_tʃʰɥæn": 942,
+ "_sk": 943,
+ "tsʰaɪ": 944,
+ "ɛtəɹ": 945,
+ "_min": 946,
+ "noʊ": 947,
+ "ʃɯ": 948,
+ "_θɹu": 949,
+ "_θɔt": 950,
+ "dajo": 951,
+ "wi": 952,
+ "i_ko": 953,
+ "_tɹ": 954,
+ "_fan": 955,
+ "ɹɛ": 956,
+ "saN": 957,
+ "_hi_wɑz": 958,
+ "_ɾe": 959,
+ "_əm": 960,
+ "te_ki": 961,
+ "_xoʊ": 962,
+ "ˈl": 963,
+ "ˈg": 964,
+ "ga_i": 965,
+ "_ɔn_ðə": 966,
+ "_xwa": 967,
+ "vɪŋ": 968,
+ "man": 969,
+ "fəɹ": 970,
+ "_oʊn": 971,
+ "ˈɹ": 972,
+ "_kɹ": 973,
+ "te_o": 974,
+ "ɪli": 975,
+ "_ʃɥɛ": 976,
+ "_fəŋ": 977,
+ "æɫ": 978,
+ "ɑp": 979,
+ "_ˈɛv": 980,
+ "eɪndʒ": 981,
+ "iɫ": 982,
+ "wət": 983,
+ "ɛðəɹ": 984,
+ "_fən": 985,
+ "ɾee": 986,
+ "_hi_hæd": 987,
+ "_maɪt": 988,
+ "_ge": 989,
+ "ækt": 990,
+ "ɪts": 991,
+ "_hɪm": 992,
+ "_ze": 993,
+ "ii": 994,
+ "_N": 995,
+ "_əv_hɪz": 996,
+ "_gɹ": 997,
+ "ænt": 998,
+ "ɪˌ": 999,
+ "_hɪmˈsɛɫf": 1000,
+ "wa_na": 1001,
+ "aɪəɹ": 1002,
+ "dʑanai": 1003,
+ "kana": 1004,
+ "aɪz": 1005,
+ "_ɪt_ɪz": 1006,
+ "mase": 1007,
+ "wɪn": 1008,
+ "əθɪŋ": 1009,
+ "_pɹəˈ": 1010,
+ "kɯn": 1011,
+ "ˈju": 1012,
+ "_fɔɹ": 1013,
+ "pʰi": 1014,
+ "pʰiŋ": 1015,
+ "o_i": 1016,
+ "vz": 1017,
+ "ɔɪn": 1018,
+ "tʰiŋ": 1019,
+ "_ne": 1020,
+ "gəɹ": 1021,
+ "æts": 1022,
+ "_ˈɹi": 1023
+ },
+ "merges": [
+ "_ t",
+ "↓ ↑",
+ "_ ˈ",
+ "ə n",
+ "_ s",
+ "a ɪ",
+ "ə ɹ",
+ "e ɪ",
+ "o ʊ",
+ "_ k",
+ "ʃ i",
+ "_ w",
+ "_ ð",
+ "t s",
+ "t ʃ",
+ "_t s",
+ "_ h",
+ "_ ə",
+ "_ m",
+ "a n",
+ "_ n",
+ "_ð ə",
+ "ɛ n",
+ "ɑ ʊ",
+ "ɑ ŋ",
+ "` ⁼",
+ "_ p",
+ "_ i",
+ "_ ɪ",
+ "_t ʃ",
+ "_ l",
+ "j ɛn",
+ "_ d",
+ "_ f",
+ "_ j",
+ "w o",
+ "_ b",
+ "t a",
+ "` ↓",
+ "t e",
+ "ən d",
+ "_ ʃi",
+ "w a",
+ "k a",
+ "ɪ ŋ",
+ "i n",
+ "s t",
+ "l i",
+ "ʊ ŋ",
+ "_t ɪ",
+ "t o",
+ "w eɪ",
+ "_ ənd",
+ "ʰ i",
+ "_ə v",
+ "ə ŋ",
+ "n o",
+ "_ x",
+ "ɾ ɯ",
+ "n a",
+ "_ a",
+ "_ ɹ",
+ "ɪ n",
+ "g a",
+ "d e",
+ "j oʊ",
+ "æ n",
+ "k ɯ",
+ "ɾ e",
+ "m a",
+ "_ðə _ˈ",
+ "ɾ a",
+ "ɛ ɹ",
+ "m o",
+ "ɔ ɹ",
+ "ə ɫ",
+ "_ g",
+ "d a",
+ "* ↑",
+ "ɪ ˈ",
+ "_ o",
+ "_ ʃ",
+ "i ŋ",
+ "j a",
+ "ə m",
+ "_ ˌ",
+ "a ʊ",
+ "_ə ˈ",
+ "` ↑",
+ "ə t",
+ "_ aɪ",
+ "o o",
+ "s ɯ",
+ "↓ .",
+ "_ɪ n",
+ "_h i",
+ "_w ɪ",
+ "ɪ z",
+ "_n a",
+ "w an",
+ "_k o",
+ "_w o",
+ "ɪ d",
+ "ɾ i",
+ "_j u",
+ "m ə",
+ "_l ə",
+ "_h æ",
+ "_ðə t",
+ "ɑ ɹ",
+ "t ʰ",
+ "k i",
+ "… …",
+ "ɑ z",
+ "_ ɔ",
+ "_m i",
+ "_w ɑz",
+ "_ˈ s",
+ "↓ ,",
+ "_t ʰ",
+ "ə ˈ",
+ "d ʑ",
+ "ɪ t",
+ "_k ʰ",
+ "i ɛ",
+ "_m a",
+ "ɪ s",
+ "ts ɯ",
+ "_n i",
+ "_ɪ t",
+ "k e",
+ "i ɑʊ",
+ "_k a",
+ "_ əɹ",
+ "n d",
+ "_ˈ p",
+ "k o",
+ "j o",
+ "ɹ i",
+ "m ən",
+ "ʊ d",
+ "_ˈ m",
+ "_f əɹ",
+ "tʃ ʰi",
+ "s a",
+ "ʰ ɥ",
+ "k ʰ",
+ "ˈ s",
+ "ɑ t",
+ "ɛ d",
+ "s e",
+ "t ʃi",
+ "ɛ ɫ",
+ "_ˈ k",
+ "_j oʊ",
+ "t əɹ",
+ "ɛ z",
+ "- -",
+ "v əɹ",
+ "` →",
+ "ʃ ən",
+ "_ɪ z",
+ "_m eɪ",
+ "_ æ",
+ "d ʒ",
+ "_k i",
+ "_h ɪz",
+ "_b i",
+ "u ɑŋ",
+ "_ˈ f",
+ "↓↑ .",
+ "_wɪ θ",
+ "j u",
+ "i ɑŋ",
+ "→ .",
+ "_s o",
+ "_h əɹ",
+ "↑ .",
+ "n i",
+ "_m o",
+ "_m aɪ",
+ "l aɪ",
+ "ɥ ɛ",
+ "_t a",
+ "ən t",
+ "_tʃ ʰi",
+ "_s ɯ",
+ "_ θ",
+ "_ ɛz",
+ "w ən",
+ "m e",
+ "m i",
+ "_hæ d",
+ "_h a",
+ "ə s",
+ "_ˈ l",
+ "_s t",
+ "ð əɹ",
+ "oʊ n",
+ "_w a",
+ "ʰ əŋ",
+ "_n ɑt",
+ "* .",
+ "k t",
+ "_ˈ h",
+ "d o",
+ "ɥ æn",
+ "n e",
+ "_t o",
+ "_w ən",
+ "_n o",
+ "_l aɪ",
+ "_w əɹ",
+ "↑ ,",
+ "→ ,",
+ "ɛ s",
+ "↓↑ ,",
+ "_ɔ n",
+ "ʰ u",
+ "s o",
+ "_ˈ b",
+ "ɫ d",
+ "ɪ k",
+ "ɪ st",
+ "_f ɹ",
+ "_ð ɛɹ",
+ "_w eɪ",
+ "ka ɾa",
+ "_ˈ d",
+ "_hæ v",
+ "ts ʰ",
+ "w aɪ",
+ "ɾ o",
+ "ɛ m",
+ "_æ t",
+ "ʊ ɹ",
+ "_ˈ w",
+ "b a",
+ "_n oʊ",
+ "ʰ jɛn",
+ "ɹ eɪ",
+ "_j o",
+ "ɸ ɯ",
+ "_s a",
+ "_ɹ ɪˈ",
+ "_ˈ n",
+ "a i",
+ "_b ət",
+ "ɪ ɹ",
+ "tʃ ʰɥ",
+ "_d ʑ",
+ "ə ˌ",
+ "_ð ɪs",
+ ". .",
+ "x wa",
+ "_ɪ m",
+ "_d ɪˈ",
+ "_k ən",
+ "dʑ i",
+ "* ,",
+ "ɑ n",
+ "_ʃi ɑŋ",
+ "_k ɯ",
+ "ʃi n",
+ "_s oʊ",
+ "b i",
+ "tʰ jɛn",
+ "te _i",
+ "_ts ʰ",
+ "_ ɯ",
+ "aɪ t",
+ "ʰi ŋ",
+ "ð ə",
+ "_ɔ ɫ",
+ "_ˈ ɹ",
+ "na i",
+ "əɹ d",
+ "_ˈ t",
+ "_ ən",
+ "_tʃ ʰɥ",
+ "_i ɛ",
+ "l eɪ",
+ "ɛɹ i",
+ "ˈ t",
+ "h a",
+ "ʃi ŋ",
+ "ɛ vəɹ",
+ "z ɯ",
+ "_w i",
+ "_j a",
+ "ɛ k",
+ "ʰ ɑŋ",
+ "_ts ɯ",
+ "_əv _ðə",
+ "ta ʃi",
+ "_s ɛd",
+ "_x ə",
+ "_l i",
+ "_s i",
+ "de sɯ",
+ "_ˌ ɪn",
+ "ʃ jɛn",
+ "_b aɪ",
+ "o n",
+ "_x ɑʊ",
+ "_ð eɪ",
+ "_x aɪ",
+ "` ↓↑",
+ "x weɪ",
+ "h i",
+ "_s e",
+ "ə _s",
+ "_fɹ əm",
+ "ʊ t",
+ "d i",
+ "aʊ t",
+ "ə b",
+ "s ɹ",
+ "ə z",
+ "_x weɪ",
+ "_kʰ ə",
+ "ɹ u",
+ "_ u",
+ "_d e",
+ "aɪ d",
+ "ɪ v",
+ "b ɯ",
+ "_h o",
+ "əɹ z",
+ "j oo",
+ "_b ɪˈ",
+ "_tʰ a",
+ "ɛ t",
+ "e n",
+ "ɛn i",
+ "ə st",
+ "æ k",
+ "ə _ts",
+ "_ˈ ɪn",
+ "t i",
+ "ɥ n",
+ "_d ʒ",
+ "x ɑʊ",
+ "_ˈ v",
+ "ʃi ɑŋ",
+ "p ʰ",
+ "_wɪ tʃ",
+ "eɪ m",
+ "oʊ z",
+ "ə ðəɹ",
+ "f ɑŋ",
+ "_ˈ g",
+ "_d o",
+ "_ʃi ɑʊ",
+ "_ˈ æ",
+ "_j ʊɹ",
+ "_ð ɛm",
+ "ɪ m",
+ "ɛ st",
+ "æn d",
+ "_d u",
+ "ɯ ɯ",
+ "k an",
+ "_d a",
+ "in o",
+ "_ e",
+ "_w ʊd",
+ "ɛn d",
+ "m eɪ",
+ "θ ɪŋ",
+ "_ʃ jɛn",
+ "i z",
+ "aɪ m",
+ "_h u",
+ "_əˈ b",
+ "ən s",
+ "_wɪ ɫ",
+ "t ʰi",
+ "g o",
+ "ɛn t",
+ "f u",
+ "æ p",
+ "x oʊ",
+ "eɪ k",
+ "ʊ k",
+ "əɹ ˈ",
+ "_θ ɪŋ",
+ "ə l",
+ "p ɹ",
+ "ə tʃ",
+ "n t",
+ "_ ɸɯ",
+ "l u",
+ "_ˈ ɔ",
+ "_i ɑʊ",
+ "l ə",
+ "t u",
+ "_dʑ i",
+ "eɪ t",
+ "_ʃi n",
+ "n na",
+ "_ˈp ɹ",
+ "f ən",
+ "_ə p",
+ "n jɛn",
+ "_a ʊt",
+ "f ɔɹ",
+ "_t u",
+ "eɪ ʃən",
+ "ɪ ɫ",
+ "_w ət",
+ "_ɪ f",
+ "_ ɥ",
+ "_f a",
+ "ˈ w",
+ "tʃ ʰjɛn",
+ "_w ɪn",
+ "oʊ ɫd",
+ "_əˈ p",
+ "aʊ nd",
+ "s an",
+ "h e",
+ "_b ɪn",
+ "f a",
+ "ɪ f",
+ "ɔ ŋ",
+ "g e",
+ "_ɪn _ðə",
+ "m iŋ",
+ "_p ɹ",
+ "in a",
+ "an o",
+ "əb əɫ",
+ "k ˈs",
+ "_ˈ ɛni",
+ "n əŋ",
+ "ə d",
+ "_əv _ðə_ˈ",
+ "_w aɪ",
+ "_t aɪm",
+ "ˈs ɛɫ",
+ "ʃi ɛ",
+ "_k əm",
+ "æ st",
+ "_g oʊ",
+ "m ɯ",
+ "ˈ p",
+ "_ˈ st",
+ "ə _t",
+ "p t",
+ "_p ʰ",
+ "ʰ ɹ",
+ "ʃ ja",
+ "i wa",
+ "ɪ l",
+ "b ət",
+ "_f ɑŋ",
+ "h o",
+ "i v",
+ "l oʊ",
+ "b e",
+ "_laɪ k",
+ "ɪ ʃ",
+ "_f u",
+ "z e",
+ "ə _tʃ",
+ "ɑɹ t",
+ "ɔɹ d",
+ "tʃʰi ŋ",
+ "m p",
+ "_ðə _s",
+ "_əˈb aʊt",
+ "_ˈ oʊ",
+ "kʰ ə",
+ "d _tɪ",
+ "ŋ ga",
+ "ə li",
+ "_kʰ an",
+ "ç i",
+ "_ˈ ju",
+ "_k ʊd",
+ "ɔ ɫ",
+ "ɔ t",
+ "_ɪ ts",
+ "_s an",
+ "tʃ a",
+ "i _na",
+ "x ə",
+ "ɛ kt",
+ "_m ɔɹ",
+ "te _kɯ",
+ "ɪd ʒ",
+ "j ʊŋ",
+ "_w an",
+ "æ t",
+ "ka t",
+ "ˈsɛɫ f",
+ "_k e",
+ "aɪ nd",
+ "i t",
+ "_ ɑɹ",
+ "s p",
+ "oʊn t",
+ "_t ʃi",
+ "tsʰ ɹ",
+ "_x ən",
+ "_əˈ g",
+ "ə _k",
+ "to _i",
+ "_t ʰi",
+ "_i ŋ",
+ "aʊ n",
+ "g ɯ",
+ "_ɪ kˈs",
+ "ɛ v",
+ "g i",
+ "k s",
+ "_s əm",
+ "an a",
+ "ɪt əɫ",
+ "n an",
+ "_ˈɪn tu",
+ "_hi ɹ",
+ "_t e",
+ "_n aʊ",
+ "ʃi ɑʊ",
+ "ʃ o",
+ "ɹ e",
+ "x aɪ",
+ "_tʃʰi ŋ",
+ "_s ɹ",
+ "_h aʊ",
+ "? .",
+ "_f eɪ",
+ "li ŋ",
+ "_ʃ ja",
+ "_ˈ dʒ",
+ "_s eɪ",
+ "ˈ n",
+ "s oʊ",
+ "tʰ ʊŋ",
+ "_l joʊ",
+ "m aɪ",
+ "_b ɹ",
+ "ɹeɪ t",
+ "_n əŋ",
+ "ʰ ə",
+ "æn s",
+ "_ˈɔ l",
+ "ta tʃi",
+ "n to",
+ "_ˌɪn ˈ",
+ "l e",
+ "n de",
+ "_ˈv ɛɹi",
+ "mən t",
+ "ɾi ma",
+ "_ð ɛn",
+ "_h əz",
+ "_ɹ i",
+ "f təɹ",
+ "_s p",
+ "ɾe wa",
+ "ga _a",
+ "z _əv",
+ "_m iŋ",
+ "_tɪ _ðə",
+ "ɹ aɪ",
+ "ɛ l",
+ "ɹ æ",
+ "_h oʊ",
+ "x u",
+ "oʊn li",
+ "ŋ k",
+ "i _i",
+ "_d ɪd",
+ "_dʒ ɪst",
+ "in g",
+ "ka i",
+ "_m æn",
+ "_i n",
+ "z o",
+ "ə f",
+ "da ke",
+ "_ˈs əm",
+ "ɾɯ _no",
+ "_g o",
+ "tʃ əɹ",
+ "i te",
+ "`↓ .",
+ "_kʰ aɪ",
+ "s k",
+ "ɔɹ s",
+ "_t ʰiŋ",
+ "_n ə",
+ "p əɫ",
+ "_tɪ _bi",
+ "ˈ fɔɹ",
+ "m u",
+ "s u",
+ "a a",
+ "ɪst əɹ",
+ "ʰ an",
+ "p əɹ",
+ "ə _p",
+ "li ɑŋ",
+ "_ v",
+ "oʊ st",
+ "_əˈg ɛn",
+ "ən z",
+ "N o",
+ "ɔɹ t",
+ "_s əˈ",
+ "_m ɯ",
+ "tʃ ʰ",
+ "_ˈl ɪtəɫ",
+ "_x wo",
+ "_ˌ bi",
+ "_ˈoʊ vəɹ",
+ "_ çi",
+ "_d eɪ",
+ "aɪ n",
+ "_ʃi ŋ",
+ "i _ʃi",
+ "_tsʰ aɪ",
+ "ʃ oo",
+ "ɾ oo",
+ "b əɹ",
+ "ʰ a",
+ "ˈ ɛs",
+ "_ɪn _ðə_ˈ",
+ "N wa",
+ "_ð ən",
+ "s aɪ",
+ "_ˈju ˈɛs",
+ "n da",
+ "_p leɪ",
+ "ɪŋ _tɪ",
+ "ɪt i",
+ "_m e",
+ "_ʃ ʊd",
+ "_n u",
+ "_ðə _k",
+ "z a",
+ "_ˈ ɛvəɹ",
+ "əɹ n",
+ "æ d",
+ "ˈ m",
+ "_d oʊnt",
+ "_m əst",
+ "j ɯɯ",
+ "ɑɹ d",
+ "_ jɛn",
+ "ʃ ɥ",
+ "_ˈ oʊnli",
+ "_ʃ o",
+ "_l iŋ",
+ "s s",
+ "ɑ l",
+ "de a",
+ "ɾe ta",
+ "m jɛn",
+ "_g ʊd",
+ "_w ɔ",
+ "i mo",
+ "no _ko",
+ "_ ɥæn",
+ "nd ʒ",
+ "ɪ ʃən",
+ "o _ʃi",
+ "_θɪŋ k",
+ "_n an",
+ "to _o",
+ "_tʰ ʊŋ",
+ "l joʊ",
+ "ta i",
+ "mə _s",
+ "_j ɯ",
+ "_ uɑŋ",
+ "_ˌbi ˈfɔɹ",
+ "æ s",
+ "_tʃ ʰjɛn",
+ "i k",
+ "_b æk",
+ "_ˈ iv",
+ "eɪ n",
+ "u n",
+ "l a",
+ "ˈ k",
+ "_d aʊn",
+ "an ai",
+ "_l ɛ",
+ "əɹ t",
+ "ð ɛɹ",
+ "_ˈæ ftəɹ",
+ "da t",
+ "f an",
+ "b əɫ",
+ "te mo",
+ "tʰ a",
+ "ɾɯ _ko",
+ "ˈ v",
+ "f eɪ",
+ "_m ətʃ",
+ "x wo",
+ "ɹ oʊ",
+ "_b a",
+ "_ˈn ɛvəɹ",
+ "_meɪ d",
+ "_j ʊŋ",
+ "_əˈp ɑn",
+ "! ?",
+ "_ˈ ʃ",
+ "_ðə_ˈ k",
+ "f t",
+ "_b o",
+ "_ɪn _ə",
+ "tʃʰɥ æn",
+ "ˈ z",
+ "`↓ ,",
+ "_bɪˈ k",
+ "ɪ g",
+ "k in",
+ "_k l",
+ "ɾɯ _n",
+ "_l ɑʊ",
+ "-- --",
+ "i ka",
+ "_ɹ aɪt",
+ "z d",
+ "z _ənd",
+ "_k jo",
+ "x wan",
+ "to o",
+ "_g ɪt",
+ "_l iɑŋ",
+ "ta _n",
+ "_k eɪm",
+ "_ˈ əðəɹ",
+ "_w ɛɫ",
+ "te ki",
+ "se e",
+ "j ɯ",
+ "i _o",
+ "to _ʃi",
+ "f əɫ",
+ "b o",
+ "ˌ t",
+ "ɪ p",
+ "an e",
+ "_tʰ jɛn",
+ "_tʃ o",
+ "ɾ jo",
+ "ɪn s",
+ "_h e",
+ "ŋ ka",
+ "ʃ ɥɛ",
+ "dʑ a",
+ "v d",
+ "ʰ wan",
+ "_g ɹeɪt",
+ "_əv _ə",
+ "ənd əɹ",
+ "ke do",
+ "_ðə _b",
+ "ə k",
+ "_t eɪk",
+ "kʰ an",
+ "_ˈɔl ˌ",
+ "s wo",
+ "_ɪt _wɑz",
+ "_ʃ ɥ",
+ "_si m",
+ "_ˈf ɑ",
+ "m in",
+ "i _a",
+ "s oo",
+ "ɛn s",
+ "_s ətʃ",
+ "tʰ aɪ",
+ "_ ga",
+ "i _ka",
+ "k oo",
+ "_fəɹ st",
+ "_ˈ tʃ",
+ "n no",
+ "ə _ɹ",
+ "ta ɾa",
+ "tʃʰ joʊ",
+ "_æ m",
+ "_m u",
+ "_meɪ k",
+ "↓ …",
+ "ɪˈ θ",
+ "ɑ b",
+ "ɹ a",
+ "_w ɛɹ",
+ "_ðə_ˈ s",
+ "_əˈ l",
+ "_ oʊɫd",
+ "æ l",
+ "_ˈp i",
+ "_l ɔŋ",
+ "dʑ o",
+ "_tʰ aɪ",
+ "ɔɹ n",
+ "əɫ z",
+ "_t əˈ",
+ "_əˈ weɪ",
+ "p a",
+ "_ð iz",
+ "_ˈs p",
+ "n n",
+ "ma e",
+ "to wa",
+ "ta _no",
+ "_ an",
+ "kʰ aɪ",
+ "ɾa ɾe",
+ "eɪ s",
+ "ɑ d",
+ "_w ɪˈθ",
+ "_ˈiv ɪn",
+ "_l u",
+ "ɔ ɪ",
+ "l ɪŋ",
+ "ət i",
+ "_ðə _f",
+ "o ʃi",
+ "_l a",
+ "s i",
+ "t ɪd",
+ "h aʊ",
+ "pʰ in",
+ "ˈ st",
+ "_ˈp əɹ",
+ "e ɹ",
+ "* !",
+ "_ˈm ɪstəɹ",
+ "ʃ a",
+ "_ˌ ɪm",
+ "ˌ θɪŋ",
+ "_n eɪ",
+ "_n ɥ",
+ "ɑ k",
+ "_ɹ u",
+ "_ʃ ɯ",
+ "_ðə_ˈ m",
+ "de mo",
+ "_d ɹ",
+ "dʑ oo",
+ "_st ɪɫ",
+ "_p ʰiŋ",
+ "ə _i",
+ "_ɪkˈs p",
+ "_w ɛnt",
+ "ɪ ɹi",
+ "əˈ m",
+ "o _ka",
+ "_əˈ k",
+ "ɔ k",
+ "_ ɥɛ",
+ "_l ʊk",
+ "ˈ d",
+ "ka ʃi",
+ "_wɪθ _ə",
+ "l jɛn",
+ "ɔ n",
+ "_l jɛn",
+ "_h ɛɫ",
+ "u ɹ",
+ "_tʰ oʊ",
+ "_tʃʰɥ æn",
+ "_s k",
+ "tsʰ aɪ",
+ "ɛ təɹ",
+ "_m in",
+ "n oʊ",
+ "ʃ ɯ",
+ "_θ ɹu",
+ "_θ ɔt",
+ "da jo",
+ "w i",
+ "i _ko",
+ "_t ɹ",
+ "_f an",
+ "ɹ ɛ",
+ "sa N",
+ "_hi _wɑz",
+ "_ ɾe",
+ "_ə m",
+ "te _ki",
+ "_x oʊ",
+ "ˈ l",
+ "ˈ g",
+ "ga _i",
+ "_ɔn _ðə",
+ "_x wa",
+ "v ɪŋ",
+ "m an",
+ "f əɹ",
+ "_ oʊn",
+ "ˈ ɹ",
+ "_k ɹ",
+ "te _o",
+ "ɪ li",
+ "_ʃ ɥɛ",
+ "_f əŋ",
+ "æ ɫ",
+ "ɑ p",
+ "_ˈ ɛv",
+ "eɪ ndʒ",
+ "i ɫ",
+ "w ət",
+ "ɛ ðəɹ",
+ "_f ən",
+ "ɾe e",
+ "_hi _hæd",
+ "_maɪ t",
+ "_g e",
+ "æ kt",
+ "ɪ ts",
+ "_h ɪm",
+ "_ ze",
+ "i i",
+ "_ N",
+ "_əv _hɪz",
+ "_g ɹ",
+ "æn t",
+ "ɪ ˌ",
+ "_hɪm ˈsɛɫf",
+ "wa _na",
+ "aɪ əɹ",
+ "dʑ anai",
+ "kan a",
+ "aɪ z",
+ "_ɪt _ɪz",
+ "ma se",
+ "w ɪn",
+ "ə θɪŋ",
+ "_pɹ əˈ",
+ "kɯ n",
+ "ˈ ju",
+ "_f ɔɹ",
+ "p ʰi",
+ "p ʰiŋ",
+ "o _i",
+ "v z",
+ "ɔ ɪn",
+ "t ʰiŋ",
+ "_n e",
+ "g əɹ",
+ "æ ts",
+ "_ˈ ɹi"
+ ]
+ }
+}
\ No newline at end of file
diff --git a/utils/g2p/bpe_69.json b/utils/g2p/bpe_69.json
new file mode 100644
index 0000000000000000000000000000000000000000..45eba9955fdcf9a7c764027b43930bf28722492b
--- /dev/null
+++ b/utils/g2p/bpe_69.json
@@ -0,0 +1,141 @@
+{
+ "version": "1.0",
+ "truncation": null,
+ "padding": null,
+ "added_tokens": [
+ {
+ "id": 0,
+ "content": "[UNK]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ },
+ {
+ "id": 1,
+ "content": "[CLS]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ },
+ {
+ "id": 2,
+ "content": "[SEP]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ },
+ {
+ "id": 3,
+ "content": "[PAD]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ },
+ {
+ "id": 4,
+ "content": "[MASK]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false,
+ "special": true
+ }
+ ],
+ "normalizer": null,
+ "pre_tokenizer": {
+ "type": "Whitespace"
+ },
+ "post_processor": null,
+ "decoder": null,
+ "model": {
+ "type": "BPE",
+ "dropout": null,
+ "unk_token": "[UNK]",
+ "continuing_subword_prefix": null,
+ "end_of_word_suffix": null,
+ "fuse_unk": false,
+ "byte_fallback": false,
+ "vocab": {
+ "[UNK]": 0,
+ "[CLS]": 1,
+ "[SEP]": 2,
+ "[PAD]": 3,
+ "[MASK]": 4,
+ "!": 5,
+ "#": 6,
+ "*": 7,
+ ",": 8,
+ "-": 9,
+ ".": 10,
+ "=": 11,
+ "?": 12,
+ "N": 13,
+ "Q": 14,
+ "^": 15,
+ "_": 16,
+ "`": 17,
+ "a": 18,
+ "b": 19,
+ "d": 20,
+ "e": 21,
+ "f": 22,
+ "g": 23,
+ "h": 24,
+ "i": 25,
+ "j": 26,
+ "k": 27,
+ "l": 28,
+ "m": 29,
+ "n": 30,
+ "o": 31,
+ "p": 32,
+ "s": 33,
+ "t": 34,
+ "u": 35,
+ "v": 36,
+ "w": 37,
+ "x": 38,
+ "y": 39,
+ "z": 40,
+ "~": 41,
+ "æ": 42,
+ "ç": 43,
+ "ð": 44,
+ "ŋ": 45,
+ "ɑ": 46,
+ "ɔ": 47,
+ "ə": 48,
+ "ɛ": 49,
+ "ɥ": 50,
+ "ɪ": 51,
+ "ɫ": 52,
+ "ɯ": 53,
+ "ɸ": 54,
+ "ɹ": 55,
+ "ɾ": 56,
+ "ʃ": 57,
+ "ʊ": 58,
+ "ʑ": 59,
+ "ʒ": 60,
+ "ʰ": 61,
+ "ˈ": 62,
+ "ˌ": 63,
+ "θ": 64,
+ "…": 65,
+ "⁼": 66,
+ "↑": 67,
+ "→": 68,
+ "↓": 69
+ },
+ "merges": [
+ ]
+ }
+}
\ No newline at end of file
diff --git a/utils/g2p/cleaners.py b/utils/g2p/cleaners.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bb53977673afa893476cab1b946d6a9a00f57ea
--- /dev/null
+++ b/utils/g2p/cleaners.py
@@ -0,0 +1,61 @@
+import re
+from utils.g2p.japanese import japanese_to_romaji_with_accent, japanese_to_ipa, japanese_to_ipa2, japanese_to_ipa3
+from utils.g2p.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2
+from utils.g2p.english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2
+patterns = [r'\[EN\](.*?)\[EN\]', r'\[ZH\](.*?)\[ZH\]', r'\[JA\](.*?)\[JA\]']
+def japanese_cleaners(text):
+ text = japanese_to_romaji_with_accent(text)
+ text = re.sub(r'([A-Za-z])$', r'\1.', text)
+ return text
+
+def japanese_cleaners2(text):
+ return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…')
+
+def chinese_cleaners(text):
+ '''Pipeline for Chinese text'''
+ text = number_to_chinese(text)
+ text = chinese_to_bopomofo(text)
+ text = latin_to_bopomofo(text)
+ text = re.sub(r'([ˉˊˇˋ˙])$', r'\1。', text)
+ return text
+
+def cje_cleaners(text):
+ matches = []
+ for pattern in patterns:
+ matches.extend(re.finditer(pattern, text))
+
+ matches.sort(key=lambda x: x.start()) # Sort matches by their start positions
+
+ outputs = ""
+ output_langs = []
+
+ for match in matches:
+ text_segment = text[match.start():match.end()]
+ phon = clean_one(text_segment)
+ if "[EN]" in text_segment:
+ lang = 'en'
+ elif "[ZH]" in text_segment:
+ lang = 'zh'
+ elif "[JA]" in text_segment:
+ lang = 'ja'
+ else:
+            raise ValueError("Unrecognized language tag; if you see this error, please report it in the issue tracker.")
+ outputs += phon
+ output_langs += [lang] * len(phon)
+ assert len(outputs) == len(output_langs)
+ return outputs, output_langs
+
+
+def clean_one(text):
+ if text.find('[ZH]') != -1:
+ text = re.sub(r'\[ZH\](.*?)\[ZH\]',
+ lambda x: chinese_to_ipa(x.group(1))+' ', text)
+ if text.find('[JA]') != -1:
+ text = re.sub(r'\[JA\](.*?)\[JA\]',
+ lambda x: japanese_to_ipa2(x.group(1))+' ', text)
+ if text.find('[EN]') != -1:
+ text = re.sub(r'\[EN\](.*?)\[EN\]',
+ lambda x: english_to_ipa2(x.group(1))+' ', text)
+ text = re.sub(r'\s+$', '', text)
+ text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
+ return text
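`cje_cleaners` expects the text to be segmented with paired language tags ([EN]...[EN], [ZH]...[ZH], [JA]...[JA]) and returns the concatenated phoneme string together with a per-character language list. A small illustration (it needs `eng_to_ipa` for the English segment and `jieba`/`pypinyin`/`cn2an` for the Chinese one):

    from utils.g2p.cleaners import cje_cleaners

    phonemes, langs = cje_cleaners("[EN]Hello.[EN][ZH]你好。[ZH]")
    # phonemes is one IPA-like string; langs[i] is 'en', 'zh' or 'ja' for phonemes[i]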
diff --git a/utils/g2p/english.py b/utils/g2p/english.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ac2166d74ce2e24ec5eb844a186d18bf29065d3
--- /dev/null
+++ b/utils/g2p/english.py
@@ -0,0 +1,188 @@
+""" from https://github.com/keithito/tacotron """
+
+'''
+Cleaners are transformations that run over the input text at both training and eval time.
+
+Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
+hyperparameter. Some cleaners are English-specific. You'll typically want to use:
+ 1. "english_cleaners" for English text
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
+ the symbols in symbols.py to match your data).
+'''
+
+
+# Regular expression matching whitespace:
+
+
+import re
+from unidecode import unidecode
+import inflect
+_inflect = inflect.engine()
+_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
+_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
+_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
+_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
+_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
+_number_re = re.compile(r'[0-9]+')
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
+ ('mrs', 'misess'),
+ ('mr', 'mister'),
+ ('dr', 'doctor'),
+ ('st', 'saint'),
+ ('co', 'company'),
+ ('jr', 'junior'),
+ ('maj', 'major'),
+ ('gen', 'general'),
+ ('drs', 'doctors'),
+ ('rev', 'reverend'),
+ ('lt', 'lieutenant'),
+ ('hon', 'honorable'),
+ ('sgt', 'sergeant'),
+ ('capt', 'captain'),
+ ('esq', 'esquire'),
+ ('ltd', 'limited'),
+ ('col', 'colonel'),
+ ('ft', 'fort'),
+]]
+
+
+# List of (ipa, lazy ipa) pairs:
+_lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
+ ('r', 'ɹ'),
+ ('æ', 'e'),
+ ('ɑ', 'a'),
+ ('ɔ', 'o'),
+ ('ð', 'z'),
+ ('θ', 's'),
+ ('ɛ', 'e'),
+ ('ɪ', 'i'),
+ ('ʊ', 'u'),
+ ('ʒ', 'ʥ'),
+ ('ʤ', 'ʥ'),
+ ('ˈ', '↓'),
+]]
+
+# List of (ipa, lazy ipa2) pairs:
+_lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
+ ('r', 'ɹ'),
+ ('ð', 'z'),
+ ('θ', 's'),
+ ('ʒ', 'ʑ'),
+ ('ʤ', 'dʑ'),
+ ('ˈ', '↓'),
+]]
+
+# List of (ipa, ipa2) pairs
+_ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
+ ('r', 'ɹ'),
+ ('ʤ', 'dʒ'),
+ ('ʧ', 'tʃ')
+]]
+
+
+def expand_abbreviations(text):
+ for regex, replacement in _abbreviations:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def collapse_whitespace(text):
+ return re.sub(r'\s+', ' ', text)
+
+
+def _remove_commas(m):
+ return m.group(1).replace(',', '')
+
+
+def _expand_decimal_point(m):
+ return m.group(1).replace('.', ' point ')
+
+
+def _expand_dollars(m):
+ match = m.group(1)
+ parts = match.split('.')
+ if len(parts) > 2:
+ return match + ' dollars' # Unexpected format
+ dollars = int(parts[0]) if parts[0] else 0
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+ if dollars and cents:
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+ cent_unit = 'cent' if cents == 1 else 'cents'
+ return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
+ elif dollars:
+ dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+ return '%s %s' % (dollars, dollar_unit)
+ elif cents:
+ cent_unit = 'cent' if cents == 1 else 'cents'
+ return '%s %s' % (cents, cent_unit)
+ else:
+ return 'zero dollars'
+
+
+def _expand_ordinal(m):
+ return _inflect.number_to_words(m.group(0))
+
+
+def _expand_number(m):
+ num = int(m.group(0))
+ if num > 1000 and num < 3000:
+ if num == 2000:
+ return 'two thousand'
+ elif num > 2000 and num < 2010:
+ return 'two thousand ' + _inflect.number_to_words(num % 100)
+ elif num % 100 == 0:
+ return _inflect.number_to_words(num // 100) + ' hundred'
+ else:
+ return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
+ else:
+ return _inflect.number_to_words(num, andword='')
+
+
+def normalize_numbers(text):
+ text = re.sub(_comma_number_re, _remove_commas, text)
+ text = re.sub(_pounds_re, r'\1 pounds', text)
+ text = re.sub(_dollars_re, _expand_dollars, text)
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
+ text = re.sub(_number_re, _expand_number, text)
+ return text
+
+
+def mark_dark_l(text):
+ return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
+
+
+def english_to_ipa(text):
+ import eng_to_ipa as ipa
+ text = unidecode(text).lower()
+ text = expand_abbreviations(text)
+ text = normalize_numbers(text)
+ phonemes = ipa.convert(text)
+ phonemes = collapse_whitespace(phonemes)
+ return phonemes
+
+
+def english_to_lazy_ipa(text):
+ text = english_to_ipa(text)
+ for regex, replacement in _lazy_ipa:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def english_to_ipa2(text):
+ text = english_to_ipa(text)
+ text = mark_dark_l(text)
+ for regex, replacement in _ipa_to_ipa2:
+ text = re.sub(regex, replacement, text)
+ return text.replace('...', '…')
+
+
+def english_to_lazy_ipa2(text):
+ text = english_to_ipa(text)
+ for regex, replacement in _lazy_ipa2:
+ text = re.sub(regex, replacement, text)
+ return text
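The English pipeline lowercases and transliterates the text with Unidecode, expands abbreviations and numbers, converts to IPA with `eng_to_ipa`, and (in `english_to_ipa2`) marks dark l and rewrites affricates. The number normalization step alone needs only `inflect` and can be sanity-checked in isolation; the output shown is approximate:

    from utils.g2p.english import normalize_numbers

    print(normalize_numbers("Dr. Smith paid $3.50 in 1999."))
    # roughly: "Dr. Smith paid three dollars, fifty cents in nineteen ninety-nine."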
diff --git a/utils/g2p/japanese.py b/utils/g2p/japanese.py
new file mode 100644
index 0000000000000000000000000000000000000000..75716c69496397e1d03fd4c2e87a38860404d11b
--- /dev/null
+++ b/utils/g2p/japanese.py
@@ -0,0 +1,154 @@
+import re
+from unidecode import unidecode
+
+
+
+# Regular expression matching Japanese without punctuation marks:
+_japanese_characters = re.compile(
+ r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
+
+# Regular expression matching non-Japanese characters or punctuation marks:
+_japanese_marks = re.compile(
+ r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
+
+# List of (symbol, Japanese) pairs for marks:
+_symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
+ ('%', 'パーセント')
+]]
+
+# List of (romaji, ipa) pairs for marks:
+_romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
+ ('ts', 'ʦ'),
+ ('u', 'ɯ'),
+ ('j', 'ʥ'),
+ ('y', 'j'),
+ ('ni', 'n^i'),
+ ('nj', 'n^'),
+ ('hi', 'çi'),
+ ('hj', 'ç'),
+ ('f', 'ɸ'),
+ ('I', 'i*'),
+ ('U', 'ɯ*'),
+ ('r', 'ɾ')
+]]
+
+# List of (romaji, ipa2) pairs for marks:
+_romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
+ ('u', 'ɯ'),
+ ('ʧ', 'tʃ'),
+ ('j', 'dʑ'),
+ ('y', 'j'),
+ ('ni', 'n^i'),
+ ('nj', 'n^'),
+ ('hi', 'çi'),
+ ('hj', 'ç'),
+ ('f', 'ɸ'),
+ ('I', 'i*'),
+ ('U', 'ɯ*'),
+ ('r', 'ɾ')
+]]
+
+# List of (consonant, sokuon) pairs:
+_real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
+ (r'Q([↑↓]*[kg])', r'k#\1'),
+ (r'Q([↑↓]*[tdjʧ])', r't#\1'),
+ (r'Q([↑↓]*[sʃ])', r's\1'),
+ (r'Q([↑↓]*[pb])', r'p#\1')
+]]
+
+# List of (consonant, hatsuon) pairs:
+_real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
+ (r'N([↑↓]*[pbm])', r'm\1'),
+ (r'N([↑↓]*[ʧʥj])', r'n^\1'),
+ (r'N([↑↓]*[tdn])', r'n\1'),
+ (r'N([↑↓]*[kg])', r'ŋ\1')
+]]
+
+
+def symbols_to_japanese(text):
+ for regex, replacement in _symbols_to_japanese:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def japanese_to_romaji_with_accent(text):
+ '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
+ import pyopenjtalk
+ text = symbols_to_japanese(text)
+ sentences = re.split(_japanese_marks, text)
+ marks = re.findall(_japanese_marks, text)
+ text = ''
+ for i, sentence in enumerate(sentences):
+ if re.match(_japanese_characters, sentence):
+ if text != '':
+ text += ' '
+ labels = pyopenjtalk.extract_fullcontext(sentence)
+ for n, label in enumerate(labels):
+ phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
+ if phoneme not in ['sil', 'pau']:
+ text += phoneme.replace('ch', 'ʧ').replace('sh',
+ 'ʃ').replace('cl', 'Q')
+ else:
+ continue
+ # n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
+ a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
+ a2 = int(re.search(r"\+(\d+)\+", label).group(1))
+ a3 = int(re.search(r"\+(\d+)/", label).group(1))
+ if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
+ a2_next = -1
+ else:
+ a2_next = int(
+ re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
+ # Accent phrase boundary
+ if a3 == 1 and a2_next == 1:
+ text += ' '
+ # Falling
+ elif a1 == 0 and a2_next == a2 + 1:
+ text += '↓'
+ # Rising
+ elif a2 == 1 and a2_next == 2:
+ text += '↑'
+ if i < len(marks):
+ text += unidecode(marks[i]).replace(' ', '')
+ return text
+
+
+def get_real_sokuon(text):
+ for regex, replacement in _real_sokuon:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def get_real_hatsuon(text):
+ for regex, replacement in _real_hatsuon:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def japanese_to_ipa(text):
+ text = japanese_to_romaji_with_accent(text).replace('...', '…')
+ text = re.sub(
+ r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
+ text = get_real_sokuon(text)
+ text = get_real_hatsuon(text)
+ for regex, replacement in _romaji_to_ipa:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def japanese_to_ipa2(text):
+ text = japanese_to_romaji_with_accent(text).replace('...', '…')
+ text = get_real_sokuon(text)
+ text = get_real_hatsuon(text)
+ for regex, replacement in _romaji_to_ipa2:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def japanese_to_ipa3(text):
+ text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace(
+ 'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a')
+ text = re.sub(
+ r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
+ text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
+ return text
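`japanese_to_romaji_with_accent` walks the pyopenjtalk full-context labels to interleave pitch-accent arrows (↑/↓) with romaji, and the `japanese_to_ipa*` variants then rewrite sokuon, hatsuon and the romaji consonants into IPA-like symbols. A usage sketch; it requires `pyopenjtalk`, and the exact output depends on its dictionary:

    from utils.g2p.japanese import japanese_to_ipa2

    print(japanese_to_ipa2("こんにちは、世界。"))
    # prints an IPA-like transcription with pitch-accent arrows interleaved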
diff --git a/utils/g2p/mandarin.py b/utils/g2p/mandarin.py
new file mode 100644
index 0000000000000000000000000000000000000000..da7680b7a4e65de8cac1c9afd9a271b0bc666a7c
--- /dev/null
+++ b/utils/g2p/mandarin.py
@@ -0,0 +1,326 @@
+import os
+import sys
+import re
+import jieba
+import cn2an
+import logging
+
+
+# List of (Latin alphabet, bopomofo) pairs:
+_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
+ ('a', 'ㄟˉ'),
+ ('b', 'ㄅㄧˋ'),
+ ('c', 'ㄙㄧˉ'),
+ ('d', 'ㄉㄧˋ'),
+ ('e', 'ㄧˋ'),
+ ('f', 'ㄝˊㄈㄨˋ'),
+ ('g', 'ㄐㄧˋ'),
+ ('h', 'ㄝˇㄑㄩˋ'),
+ ('i', 'ㄞˋ'),
+ ('j', 'ㄐㄟˋ'),
+ ('k', 'ㄎㄟˋ'),
+ ('l', 'ㄝˊㄛˋ'),
+ ('m', 'ㄝˊㄇㄨˋ'),
+ ('n', 'ㄣˉ'),
+ ('o', 'ㄡˉ'),
+ ('p', 'ㄆㄧˉ'),
+ ('q', 'ㄎㄧㄡˉ'),
+ ('r', 'ㄚˋ'),
+ ('s', 'ㄝˊㄙˋ'),
+ ('t', 'ㄊㄧˋ'),
+ ('u', 'ㄧㄡˉ'),
+ ('v', 'ㄨㄧˉ'),
+ ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
+ ('x', 'ㄝˉㄎㄨˋㄙˋ'),
+ ('y', 'ㄨㄞˋ'),
+ ('z', 'ㄗㄟˋ')
+]]
+
+# List of (bopomofo, romaji) pairs:
+_bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [
+ ('ㄅㄛ', 'p⁼wo'),
+ ('ㄆㄛ', 'pʰwo'),
+ ('ㄇㄛ', 'mwo'),
+ ('ㄈㄛ', 'fwo'),
+ ('ㄅ', 'p⁼'),
+ ('ㄆ', 'pʰ'),
+ ('ㄇ', 'm'),
+ ('ㄈ', 'f'),
+ ('ㄉ', 't⁼'),
+ ('ㄊ', 'tʰ'),
+ ('ㄋ', 'n'),
+ ('ㄌ', 'l'),
+ ('ㄍ', 'k⁼'),
+ ('ㄎ', 'kʰ'),
+ ('ㄏ', 'h'),
+ ('ㄐ', 'ʧ⁼'),
+ ('ㄑ', 'ʧʰ'),
+ ('ㄒ', 'ʃ'),
+ ('ㄓ', 'ʦ`⁼'),
+ ('ㄔ', 'ʦ`ʰ'),
+ ('ㄕ', 's`'),
+ ('ㄖ', 'ɹ`'),
+ ('ㄗ', 'ʦ⁼'),
+ ('ㄘ', 'ʦʰ'),
+ ('ㄙ', 's'),
+ ('ㄚ', 'a'),
+ ('ㄛ', 'o'),
+ ('ㄜ', 'ə'),
+ ('ㄝ', 'e'),
+ ('ㄞ', 'ai'),
+ ('ㄟ', 'ei'),
+ ('ㄠ', 'au'),
+ ('ㄡ', 'ou'),
+ ('ㄧㄢ', 'yeNN'),
+ ('ㄢ', 'aNN'),
+ ('ㄧㄣ', 'iNN'),
+ ('ㄣ', 'əNN'),
+ ('ㄤ', 'aNg'),
+ ('ㄧㄥ', 'iNg'),
+ ('ㄨㄥ', 'uNg'),
+ ('ㄩㄥ', 'yuNg'),
+ ('ㄥ', 'əNg'),
+ ('ㄦ', 'əɻ'),
+ ('ㄧ', 'i'),
+ ('ㄨ', 'u'),
+ ('ㄩ', 'ɥ'),
+ ('ˉ', '→'),
+ ('ˊ', '↑'),
+ ('ˇ', '↓↑'),
+ ('ˋ', '↓'),
+ ('˙', ''),
+ (',', ','),
+ ('。', '.'),
+ ('!', '!'),
+ ('?', '?'),
+ ('—', '-')
+]]
+
+# List of (romaji, ipa) pairs:
+_romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
+ ('ʃy', 'ʃ'),
+ ('ʧʰy', 'ʧʰ'),
+ ('ʧ⁼y', 'ʧ⁼'),
+ ('NN', 'n'),
+ ('Ng', 'ŋ'),
+ ('y', 'j'),
+ ('h', 'x')
+]]
+
+# List of (bopomofo, ipa) pairs:
+_bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
+ ('ㄅㄛ', 'p⁼wo'),
+ ('ㄆㄛ', 'pʰwo'),
+ ('ㄇㄛ', 'mwo'),
+ ('ㄈㄛ', 'fwo'),
+ ('ㄅ', 'p⁼'),
+ ('ㄆ', 'pʰ'),
+ ('ㄇ', 'm'),
+ ('ㄈ', 'f'),
+ ('ㄉ', 't⁼'),
+ ('ㄊ', 'tʰ'),
+ ('ㄋ', 'n'),
+ ('ㄌ', 'l'),
+ ('ㄍ', 'k⁼'),
+ ('ㄎ', 'kʰ'),
+ ('ㄏ', 'x'),
+ ('ㄐ', 'tʃ⁼'),
+ ('ㄑ', 'tʃʰ'),
+ ('ㄒ', 'ʃ'),
+ ('ㄓ', 'ts`⁼'),
+ ('ㄔ', 'ts`ʰ'),
+ ('ㄕ', 's`'),
+ ('ㄖ', 'ɹ`'),
+ ('ㄗ', 'ts⁼'),
+ ('ㄘ', 'tsʰ'),
+ ('ㄙ', 's'),
+ ('ㄚ', 'a'),
+ ('ㄛ', 'o'),
+ ('ㄜ', 'ə'),
+ ('ㄝ', 'ɛ'),
+ ('ㄞ', 'aɪ'),
+ ('ㄟ', 'eɪ'),
+ ('ㄠ', 'ɑʊ'),
+ ('ㄡ', 'oʊ'),
+ ('ㄧㄢ', 'jɛn'),
+ ('ㄩㄢ', 'ɥæn'),
+ ('ㄢ', 'an'),
+ ('ㄧㄣ', 'in'),
+ ('ㄩㄣ', 'ɥn'),
+ ('ㄣ', 'ən'),
+ ('ㄤ', 'ɑŋ'),
+ ('ㄧㄥ', 'iŋ'),
+ ('ㄨㄥ', 'ʊŋ'),
+ ('ㄩㄥ', 'jʊŋ'),
+ ('ㄥ', 'əŋ'),
+ ('ㄦ', 'əɻ'),
+ ('ㄧ', 'i'),
+ ('ㄨ', 'u'),
+ ('ㄩ', 'ɥ'),
+ ('ˉ', '→'),
+ ('ˊ', '↑'),
+ ('ˇ', '↓↑'),
+ ('ˋ', '↓'),
+ ('˙', ''),
+ (',', ','),
+ ('。', '.'),
+ ('!', '!'),
+ ('?', '?'),
+ ('—', '-')
+]]
+
+# List of (bopomofo, ipa2) pairs:
+_bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
+ ('ㄅㄛ', 'pwo'),
+ ('ㄆㄛ', 'pʰwo'),
+ ('ㄇㄛ', 'mwo'),
+ ('ㄈㄛ', 'fwo'),
+ ('ㄅ', 'p'),
+ ('ㄆ', 'pʰ'),
+ ('ㄇ', 'm'),
+ ('ㄈ', 'f'),
+ ('ㄉ', 't'),
+ ('ㄊ', 'tʰ'),
+ ('ㄋ', 'n'),
+ ('ㄌ', 'l'),
+ ('ㄍ', 'k'),
+ ('ㄎ', 'kʰ'),
+ ('ㄏ', 'h'),
+ ('ㄐ', 'tɕ'),
+ ('ㄑ', 'tɕʰ'),
+ ('ㄒ', 'ɕ'),
+ ('ㄓ', 'tʂ'),
+ ('ㄔ', 'tʂʰ'),
+ ('ㄕ', 'ʂ'),
+ ('ㄖ', 'ɻ'),
+ ('ㄗ', 'ts'),
+ ('ㄘ', 'tsʰ'),
+ ('ㄙ', 's'),
+ ('ㄚ', 'a'),
+ ('ㄛ', 'o'),
+ ('ㄜ', 'ɤ'),
+ ('ㄝ', 'ɛ'),
+ ('ㄞ', 'aɪ'),
+ ('ㄟ', 'eɪ'),
+ ('ㄠ', 'ɑʊ'),
+ ('ㄡ', 'oʊ'),
+ ('ㄧㄢ', 'jɛn'),
+ ('ㄩㄢ', 'yæn'),
+ ('ㄢ', 'an'),
+ ('ㄧㄣ', 'in'),
+ ('ㄩㄣ', 'yn'),
+ ('ㄣ', 'ən'),
+ ('ㄤ', 'ɑŋ'),
+ ('ㄧㄥ', 'iŋ'),
+ ('ㄨㄥ', 'ʊŋ'),
+ ('ㄩㄥ', 'jʊŋ'),
+ ('ㄥ', 'ɤŋ'),
+ ('ㄦ', 'əɻ'),
+ ('ㄧ', 'i'),
+ ('ㄨ', 'u'),
+ ('ㄩ', 'y'),
+ ('ˉ', '˥'),
+ ('ˊ', '˧˥'),
+ ('ˇ', '˨˩˦'),
+ ('ˋ', '˥˩'),
+ ('˙', ''),
+ (',', ','),
+ ('。', '.'),
+ ('!', '!'),
+ ('?', '?'),
+ ('—', '-')
+]]
+
+
+def number_to_chinese(text):
+ numbers = re.findall(r'\d+(?:\.?\d+)?', text)
+ for number in numbers:
+ text = text.replace(number, cn2an.an2cn(number), 1)
+ return text
+
+
+def chinese_to_bopomofo(text):
+ from pypinyin import lazy_pinyin, BOPOMOFO
+ text = text.replace('、', ',').replace(';', ',').replace(':', ',')
+ words = jieba.lcut(text, cut_all=False)
+ text = ''
+ for word in words:
+ bopomofos = lazy_pinyin(word, BOPOMOFO)
+ if not re.search('[\u4e00-\u9fff]', word):
+ text += word
+ continue
+ for i in range(len(bopomofos)):
+ bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i])
+ if text != '':
+ text += ' '
+ text += ''.join(bopomofos)
+ return text
+
+
+def latin_to_bopomofo(text):
+ for regex, replacement in _latin_to_bopomofo:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def bopomofo_to_romaji(text):
+ for regex, replacement in _bopomofo_to_romaji:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def bopomofo_to_ipa(text):
+ for regex, replacement in _bopomofo_to_ipa:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def bopomofo_to_ipa2(text):
+ for regex, replacement in _bopomofo_to_ipa2:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def chinese_to_romaji(text):
+ text = number_to_chinese(text)
+ text = chinese_to_bopomofo(text)
+ text = latin_to_bopomofo(text)
+ text = bopomofo_to_romaji(text)
+ text = re.sub('i([aoe])', r'y\1', text)
+ text = re.sub('u([aoəe])', r'w\1', text)
+ text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
+ r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
+ text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
+ return text
+
+
+def chinese_to_lazy_ipa(text):
+ text = chinese_to_romaji(text)
+ for regex, replacement in _romaji_to_ipa:
+ text = re.sub(regex, replacement, text)
+ return text
+
+
+def chinese_to_ipa(text):
+ text = number_to_chinese(text)
+ text = chinese_to_bopomofo(text)
+ text = latin_to_bopomofo(text)
+ text = bopomofo_to_ipa(text)
+ text = re.sub('i([aoe])', r'j\1', text)
+ text = re.sub('u([aoəe])', r'w\1', text)
+ text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
+ r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
+ text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
+ return text
+
+
+def chinese_to_ipa2(text):
+ text = number_to_chinese(text)
+ text = chinese_to_bopomofo(text)
+ text = latin_to_bopomofo(text)
+ text = bopomofo_to_ipa2(text)
+ text = re.sub(r'i([aoe])', r'j\1', text)
+ text = re.sub(r'u([aoəe])', r'w\1', text)
+ text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text)
+ text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text)
+ return text
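The Mandarin pipeline converts digits to Chinese numerals with cn2an, segments the text with jieba, turns each word into bopomofo via pypinyin, and finally maps bopomofo plus tone marks to IPA with tone arrows. A usage sketch; it requires `jieba`, `pypinyin` and `cn2an`:

    from utils.g2p.mandarin import chinese_to_ipa

    print(chinese_to_ipa("今天天气很好。"))
    # prints an IPA-like string with a tone arrow (→ ↑ ↓↑ ↓) after each syllable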
diff --git a/utils/g2p/symbols.py b/utils/g2p/symbols.py
new file mode 100644
index 0000000000000000000000000000000000000000..789e9df25d3d93d1976ef22d15d77f51d170ed00
--- /dev/null
+++ b/utils/g2p/symbols.py
@@ -0,0 +1,76 @@
+'''
+Defines the set of symbols used in text input to the model.
+'''
+
+# japanese_cleaners
+# _pad = '_'
+# _punctuation = ',.!?-'
+# _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
+
+
+'''# japanese_cleaners2
+_pad = '_'
+_punctuation = ',.!?-~…'
+_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
+'''
+
+
+'''# korean_cleaners
+_pad = '_'
+_punctuation = ',.!?…~'
+_letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
+'''
+
+'''# chinese_cleaners
+_pad = '_'
+_punctuation = ',。!?—…'
+_letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
+'''
+
+# # zh_ja_mixture_cleaners
+# _pad = '_'
+# _punctuation = ',.!?-~…'
+# _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
+
+
+'''# sanskrit_cleaners
+_pad = '_'
+_punctuation = '।'
+_letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
+'''
+
+'''# cjks_cleaners
+_pad = '_'
+_punctuation = ',.!?-~…'
+_letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
+'''
+
+'''# thai_cleaners
+_pad = '_'
+_punctuation = '.!? '
+_letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
+'''
+
+# # cjke_cleaners2
+_pad = '_'
+_punctuation = ',.!?-~…'
+_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
+
+
+'''# shanghainese_cleaners
+_pad = '_'
+_punctuation = ',.!?…'
+_letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
+'''
+
+'''# chinese_dialect_cleaners
+_pad = '_'
+_punctuation = ',.!?~…─'
+_letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ '
+'''
+
+# Export all symbols:
+symbols = [_pad] + list(_punctuation) + list(_letters)
+
+# Special symbol ids
+SPACE_ID = symbols.index(" ")
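Only the `cjke_cleaners2` symbol set is active; the commented-out blocks are alternative configurations kept for reference. Symbol ids are simply list indices, which is what `text_to_sequence` and `cleaned_text_to_sequence` in `utils/g2p/__init__.py` rely on. For example:

    from utils.g2p.symbols import symbols, SPACE_ID

    print(symbols.index('_'))   # 0: the padding symbol comes first
    print(SPACE_ID)             # index of ' ', the last character of _letters
    print(len(symbols))         # size of the active cjke_cleaners2 vocabulary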
diff --git a/utils/generation.py b/utils/generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0952dba48146ba3be2f3f1133b9f0640ad30f3c
--- /dev/null
+++ b/utils/generation.py
@@ -0,0 +1,257 @@
+import os
+import torch
+import gdown
+import logging
+import psutil
+import langid
+langid.set_languages(['en', 'zh', 'ja'])
+
+import pathlib
+import platform
+if platform.system().lower() == 'windows':
+ temp = pathlib.PosixPath
+ pathlib.PosixPath = pathlib.WindowsPath
+elif platform.system().lower() == 'linux':
+ temp = pathlib.WindowsPath
+ pathlib.WindowsPath = pathlib.PosixPath
+
+import numpy as np
+from data.tokenizer import (
+ AudioTokenizer,
+ tokenize_audio,
+)
+from data.collation import get_text_token_collater
+from models.vallex import VALLE
+from utils.g2p import PhonemeBpeTokenizer
+from utils.sentence_cutter import split_text_into_sentences
+
+from macros import *
+
+device = torch.device("cpu")
+if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+url = 'https://drive.google.com/file/d/10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl/view?usp=sharing'
+
+checkpoints_dir = "./checkpoints/"
+
+model_checkpoint_name = "vallex-checkpoint.pt"
+
+model = None
+
+codec = None
+
+text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
+text_collater = get_text_token_collater()
+
+def preload_models():
+ global model, codec
+ if not os.path.exists(checkpoints_dir): os.mkdir(checkpoints_dir)
+ if not os.path.exists(os.path.join(checkpoints_dir, model_checkpoint_name)):
+ gdown.download(id="10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl", output=os.path.join(checkpoints_dir, model_checkpoint_name), quiet=False)
+ # VALL-E
+ model = VALLE(
+ N_DIM,
+ NUM_HEAD,
+ NUM_LAYERS,
+ norm_first=True,
+ add_prenet=False,
+ prefix_mode=PREFIX_MODE,
+ share_embedding=True,
+ nar_scale_factor=1.0,
+ prepend_bos=True,
+ num_quantizers=NUM_QUANTIZERS,
+ ).to(device)
+ checkpoint = torch.load(os.path.join(checkpoints_dir, model_checkpoint_name), map_location='cpu')
+ missing_keys, unexpected_keys = model.load_state_dict(
+ checkpoint["model"], strict=True
+ )
+ assert not missing_keys
+ model.eval()
+
+ # Encodec
+ codec = AudioTokenizer(device)
+
+@torch.no_grad()
+def generate_audio(text, prompt=None, language='auto', accent='no-accent'):
+ global model, codec, text_tokenizer, text_collater
+ text = text.replace("\n", "").strip(" ")
+ # detect language
+ if language == "auto":
+ language = langid.classify(text)[0]
+ lang_token = lang2token[language]
+ lang = token2lang[lang_token]
+ text = lang_token + text + lang_token
+
+ # load prompt
+ if prompt is not None:
+ prompt_path = prompt
+ if not os.path.exists(prompt_path):
+ prompt_path = "./presets/" + prompt + ".npz"
+ if not os.path.exists(prompt_path):
+ prompt_path = "./customs/" + prompt + ".npz"
+ if not os.path.exists(prompt_path):
+ raise ValueError(f"Cannot find prompt {prompt}")
+ prompt_data = np.load(prompt_path)
+ audio_prompts = prompt_data['audio_tokens']
+ text_prompts = prompt_data['text_tokens']
+ lang_pr = prompt_data['lang_code']
+ lang_pr = code2lang[int(lang_pr)]
+
+ # numpy to tensor
+ audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
+ text_prompts = torch.tensor(text_prompts).type(torch.int32)
+ else:
+ audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
+ text_prompts = torch.zeros([1, 0]).type(torch.int32)
+ lang_pr = lang if lang != 'mix' else 'en'
+
+ enroll_x_lens = text_prompts.shape[-1]
+ logging.info(f"synthesize text: {text}")
+ phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
+ text_tokens, text_tokens_lens = text_collater(
+ [
+ phone_tokens
+ ]
+ )
+ text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
+ text_tokens_lens += enroll_x_lens
+ # accent control
+ lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
+ encoded_frames = model.inference(
+ text_tokens.to(device),
+ text_tokens_lens.to(device),
+ audio_prompts,
+ enroll_x_lens=enroll_x_lens,
+ top_k=-100,
+ temperature=1,
+ prompt_language=lang_pr,
+ text_language=langs if accent == "no-accent" else lang,
+ )
+ samples = codec.decode(
+ [(encoded_frames.transpose(2, 1), None)]
+ )
+
+ return samples[0][0].cpu().numpy()
+
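+# Usage sketch (illustrative only, not part of the original module): the output
+# path is a placeholder and 24000 Hz is assumed to match the codec sample rate.
+#
+#   from utils.generation import preload_models, generate_audio
+#   import soundfile as sf
+#
+#   preload_models()                          # download checkpoint, build VALL-E X + EnCodec
+#   wav = generate_audio("Hello world.",      # text to synthesize
+#                        prompt=None,         # or the name of a preset/custom .npz prompt
+#                        language='auto')     # let langid pick the language token
+#   sf.write("output.wav", wav, 24000)
+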
+@torch.no_grad()
+def generate_audio_from_long_text(text, prompt=None, language='auto', accent='no-accent', mode='sliding-window'):
+ """
+ For long audio generation, two modes are available.
+ fixed-prompt: This mode will keep using the same prompt the user has provided, and generate audio sentence by sentence.
+ sliding-window: This mode will use the last sentence as the prompt for the next sentence, but has some concern on speaker maintenance.
+ """
+ global model, codec, text_tokenizer, text_collater
+ if prompt is None or prompt == "":
+ mode = 'sliding-window' # If no prompt is given, use sliding-window mode
+ sentences = split_text_into_sentences(text)
+ # detect language
+ if language == "auto":
+ language = langid.classify(text)[0]
+
+ # if initial prompt is given, encode it
+ if prompt is not None and prompt != "":
+ prompt_path = prompt
+ if not os.path.exists(prompt_path):
+ prompt_path = "./presets/" + prompt + ".npz"
+ if not os.path.exists(prompt_path):
+ prompt_path = "./customs/" + prompt + ".npz"
+ if not os.path.exists(prompt_path):
+ raise ValueError(f"Cannot find prompt {prompt}")
+ prompt_data = np.load(prompt_path)
+ audio_prompts = prompt_data['audio_tokens']
+ text_prompts = prompt_data['text_tokens']
+ lang_pr = prompt_data['lang_code']
+ lang_pr = code2lang[int(lang_pr)]
+
+ # numpy to tensor
+ audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
+ text_prompts = torch.tensor(text_prompts).type(torch.int32)
+ else:
+ audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
+ text_prompts = torch.zeros([1, 0]).type(torch.int32)
+ lang_pr = language if language != 'mix' else 'en'
+ if mode == 'fixed-prompt':
+ complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
+ for text in sentences:
+ text = text.replace("\n", "").strip(" ")
+ if text == "":
+ continue
+ lang_token = lang2token[language]
+ lang = token2lang[lang_token]
+ text = lang_token + text + lang_token
+
+ enroll_x_lens = text_prompts.shape[-1]
+ logging.info(f"synthesize text: {text}")
+ phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
+ text_tokens, text_tokens_lens = text_collater(
+ [
+ phone_tokens
+ ]
+ )
+ text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
+ text_tokens_lens += enroll_x_lens
+ # accent control
+ lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
+ encoded_frames = model.inference(
+ text_tokens.to(device),
+ text_tokens_lens.to(device),
+ audio_prompts,
+ enroll_x_lens=enroll_x_lens,
+ top_k=-100,
+ temperature=1,
+ prompt_language=lang_pr,
+ text_language=langs if accent == "no-accent" else lang,
+ )
+ complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
+ samples = codec.decode(
+ [(complete_tokens, None)]
+ )
+ return samples[0][0].cpu().numpy()
+ elif mode == "sliding-window":
+ complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
+ original_audio_prompts = audio_prompts
+ original_text_prompts = text_prompts
+ for text in sentences:
+ text = text.replace("\n", "").strip(" ")
+ if text == "":
+ continue
+ lang_token = lang2token[language]
+ lang = token2lang[lang_token]
+ text = lang_token + text + lang_token
+
+ enroll_x_lens = text_prompts.shape[-1]
+ logging.info(f"synthesize text: {text}")
+ phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
+ text_tokens, text_tokens_lens = text_collater(
+ [
+ phone_tokens
+ ]
+ )
+ text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
+ text_tokens_lens += enroll_x_lens
+ # accent control
+ lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
+ encoded_frames = model.inference(
+ text_tokens.to(device),
+ text_tokens_lens.to(device),
+ audio_prompts,
+ enroll_x_lens=enroll_x_lens,
+ top_k=-100,
+ temperature=1,
+ prompt_language=lang_pr,
+ text_language=langs if accent == "no-accent" else lang,
+ )
+ complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
+ if torch.rand(1) < 0.5:
+ audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]
+ text_prompts = text_tokens[:, enroll_x_lens:]
+ else:
+ audio_prompts = original_audio_prompts
+ text_prompts = original_text_prompts
+ samples = codec.decode(
+ [(complete_tokens, None)]
+ )
+ return samples[0][0].cpu().numpy()
+ else:
+ raise ValueError(f"No such mode {mode}")
diff --git a/utils/prompt_making.py b/utils/prompt_making.py
new file mode 100644
index 0000000000000000000000000000000000000000..93e4a3d647052df4899253fea41be22f09e006b8
--- /dev/null
+++ b/utils/prompt_making.py
@@ -0,0 +1,115 @@
+import os
+import torch
+import torchaudio
+import logging
+import langid
+import whisper
+langid.set_languages(['en', 'zh', 'ja'])
+
+import numpy as np
+from data.tokenizer import (
+ AudioTokenizer,
+ tokenize_audio,
+)
+from data.collation import get_text_token_collater
+from utils.g2p import PhonemeBpeTokenizer
+
+from macros import *
+
+text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
+text_collater = get_text_token_collater()
+
+device = torch.device("cpu")
+if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+codec = AudioTokenizer(device)
+
+whisper_model = None
+
+@torch.no_grad()
+def transcribe_one(model, audio_path):
+ # load audio and pad/trim it to fit 30 seconds
+ audio = whisper.load_audio(audio_path)
+ audio = whisper.pad_or_trim(audio)
+
+ # make log-Mel spectrogram and move to the same device as the model
+ mel = whisper.log_mel_spectrogram(audio).to(model.device)
+
+ # detect the spoken language
+ _, probs = model.detect_language(mel)
+ print(f"Detected language: {max(probs, key=probs.get)}")
+ lang = max(probs, key=probs.get)
+ # decode the audio
+ options = whisper.DecodingOptions(temperature=1.0, best_of=5, fp16=False if device == torch.device("cpu") else True, sample_len=150)
+ result = whisper.decode(model, mel, options)
+
+ # print the recognized text
+ print(result.text)
+
+ text_pr = result.text
+ if text_pr.strip(" ")[-1] not in "?!.,。,?!。、":
+ text_pr += "."
+ return lang, text_pr
+
+def make_prompt(name, audio_prompt_path, transcript=None):
+ global model, text_collater, text_tokenizer, codec
+ wav_pr, sr = torchaudio.load(audio_prompt_path)
+ # check length
+ if wav_pr.size(-1) / sr > 15:
+ raise ValueError(f"Prompt too long, expect length below 15 seconds, got {wav_pr / sr} seconds.")
+ if wav_pr.size(0) == 2:
+ wav_pr = wav_pr.mean(0, keepdim=True)
+ text_pr, lang_pr = make_transcript(name, wav_pr, sr, transcript)
+
+ # tokenize audio
+ encoded_frames = tokenize_audio(codec, (wav_pr, sr))
+ audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
+
+ # tokenize text
+ phonemes, langs = text_tokenizer.tokenize(text=f"{text_pr}".strip())
+ text_tokens, enroll_x_lens = text_collater(
+ [
+ phonemes
+ ]
+ )
+
+ message = f"Detected language: {lang_pr}\n Detected text {text_pr}\n"
+
+ # save as npz file
+ save_path = os.path.join("./customs/", f"{name}.npz")
+ np.savez(save_path, audio_tokens=audio_tokens, text_tokens=text_tokens, lang_code=lang2code[lang_pr])
+ logging.info(f"Successful. Prompt saved to {save_path}")
+
+
+def make_transcript(name, wav, sr, transcript=None):
+
+ if not isinstance(wav, torch.FloatTensor):
+ wav = torch.tensor(wav)
+ if wav.abs().max() > 1:
+ wav /= wav.abs().max()
+ if wav.size(-1) == 2:
+ wav = wav.mean(-1, keepdim=False)
+ if wav.ndim == 1:
+ wav = wav.unsqueeze(0)
+    assert wav.ndim == 2 and wav.size(0) == 1
+ if transcript is None or transcript == "":
+ logging.info("Transcript not given, using Whisper...")
+ global whisper_model
+ if whisper_model is None:
+ whisper_model = whisper.load_model("medium")
+ whisper_model.to(device)
+ torchaudio.save(f"./prompts/{name}.wav", wav, sr)
+ lang, text = transcribe_one(whisper_model, f"./prompts/{name}.wav")
+ lang_token = lang2token[lang]
+ text = lang_token + text + lang_token
+ os.remove(f"./prompts/{name}.wav")
+ whisper_model.cpu()
+ else:
+ text = transcript
+ lang, _ = langid.classify(text)
+ lang_token = lang2token[lang]
+ text = lang_token + text + lang_token
+
+ torch.cuda.empty_cache()
+ return text, lang
\ No newline at end of file
diff --git a/utils/sentence_cutter.py b/utils/sentence_cutter.py
new file mode 100644
index 0000000000000000000000000000000000000000..02b8b141b5b139709fb477653a5c3804cd74fb2d
--- /dev/null
+++ b/utils/sentence_cutter.py
@@ -0,0 +1,43 @@
+import nltk
+import jieba
+import sudachipy
+import langid
+nltk.download('punkt')
+langid.set_languages(['en', 'zh', 'ja'])
+
+def split_text_into_sentences(text):
+ if langid.classify(text)[0] == "en":
+ sentences = nltk.tokenize.sent_tokenize(text)
+
+ return sentences
+ elif langid.classify(text)[0] == "zh":
+ sentences = []
+ segs = jieba.cut(text, cut_all=False)
+ segs = list(segs)
+ start = 0
+ for i, seg in enumerate(segs):
+ if seg in ["。", "!", "?", "……"]:
+ sentences.append("".join(segs[start:i + 1]))
+ start = i + 1
+ if start < len(segs):
+ sentences.append("".join(segs[start:]))
+
+ return sentences
+ elif langid.classify(text)[0] == "ja":
+ sentences = []
+ tokenizer = sudachipy.Dictionary().create()
+ tokens = tokenizer.tokenize(text)
+ current_sentence = ""
+
+ for token in tokens:
+ current_sentence += token.surface()
+ if token.part_of_speech()[0] == "補助記号" and token.part_of_speech()[1] == "句点":
+ sentences.append(current_sentence)
+ current_sentence = ""
+
+ if current_sentence:
+ sentences.append(current_sentence)
+
+ return sentences
+
+ raise RuntimeError("It is impossible to reach here.")
\ No newline at end of file
diff --git a/utils/symbol_table.py b/utils/symbol_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a86010a76280576f85490641623dbb27559aa99
--- /dev/null
+++ b/utils/symbol_table.py
@@ -0,0 +1,287 @@
+# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
+#
+# See ../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from dataclasses import field
+from typing import Dict
+from typing import Generic
+from typing import List
+from typing import Optional
+from typing import TypeVar
+from typing import Union
+
+Symbol = TypeVar('Symbol')
+
+
+# Disable __repr__ otherwise it could freeze e.g. Jupyter.
+@dataclass(repr=False)
+class SymbolTable(Generic[Symbol]):
+    '''SymbolTable maps the symbol IDs found on FSA arcs to actual
+    objects. These objects can be arbitrary Python objects that can
+    serve as keys in a dictionary (i.e. they need to be hashable and
+    immutable).
+
+    The SymbolTable can only be written to/read from disk if the
+    symbols are strings.
+    '''
+ _id2sym: Dict[int, Symbol] = field(default_factory=dict)
+ '''Map an integer to a symbol.
+ '''
+
+ _sym2id: Dict[Symbol, int] = field(default_factory=dict)
+ '''Map a symbol to an integer.
+ '''
+
+ _next_available_id: int = 1
+ '''A helper internal field that helps adding new symbols
+ to the table efficiently.
+ '''
+
+ eps: Symbol = ''
+ '''Null symbol, always mapped to index 0.
+ '''
+
+ def __post_init__(self):
+ for idx, sym in self._id2sym.items():
+ assert self._sym2id[sym] == idx
+ assert idx >= 0
+
+ for sym, idx in self._sym2id.items():
+ assert idx >= 0
+ assert self._id2sym[idx] == sym
+
+ if 0 not in self._id2sym:
+ self._id2sym[0] = self.eps
+ self._sym2id[self.eps] = 0
+ else:
+ assert self._id2sym[0] == self.eps
+ assert self._sym2id[self.eps] == 0
+
+ self._next_available_id = max(self._id2sym) + 1
+
+ @staticmethod
+ def from_str(s: str) -> 'SymbolTable':
+ '''Build a symbol table from a string.
+
+ The string consists of lines. Every line has two fields separated
+ by space(s), tab(s) or both. The first field is the symbol and the
+ second the integer id of the symbol.
+
+ Args:
+ s:
+ The input string with the format described above.
+ Returns:
+ An instance of :class:`SymbolTable`.
+ '''
+ id2sym: Dict[int, str] = dict()
+ sym2id: Dict[str, int] = dict()
+
+ for line in s.split('\n'):
+ fields = line.split()
+ if len(fields) == 0:
+ continue # skip empty lines
+ assert len(fields) == 2, \
+ f'Expect a line with 2 fields. Given: {len(fields)}'
+ sym, idx = fields[0], int(fields[1])
+ assert sym not in sym2id, f'Duplicated symbol {sym}'
+ assert idx not in id2sym, f'Duplicated id {idx}'
+ id2sym[idx] = sym
+ sym2id[sym] = idx
+
+ eps = id2sym.get(0, '')
+
+ return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)
+
+ @staticmethod
+ def from_file(filename: str) -> 'SymbolTable':
+ '''Build a symbol table from file.
+
+ Every line in the symbol table file has two fields separated by
+ space(s), tab(s) or both. The following is an example file:
+
+ .. code-block::
+
+ 0
+ a 1
+ b 2
+ c 3
+
+ Args:
+ filename:
+ Name of the symbol table file. Its format is documented above.
+
+ Returns:
+ An instance of :class:`SymbolTable`.
+
+ '''
+ with open(filename, 'r', encoding='utf-8') as f:
+ return SymbolTable.from_str(f.read().strip())
+
+ def to_str(self) -> str:
+ '''
+ Returns:
+ Return a string representation of this object. You can pass
+ it to the method ``from_str`` to recreate an identical object.
+ '''
+ s = ''
+ for idx, symbol in sorted(self._id2sym.items()):
+ s += f'{symbol} {idx}\n'
+ return s
+
+ def to_file(self, filename: str):
+ '''Serialize the SymbolTable to a file.
+
+ Every line in the symbol table file has two fields separated by
+ space(s), tab(s) or both. The following is an example file:
+
+ .. code-block::
+
+ 0
+ a 1
+ b 2
+ c 3
+
+ Args:
+ filename:
+ Name of the symbol table file. Its format is documented above.
+ '''
+ with open(filename, 'w') as f:
+ for idx, symbol in sorted(self._id2sym.items()):
+ print(symbol, idx, file=f)
+
+ def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
+ '''Add a new symbol to the SymbolTable.
+
+ Args:
+ symbol:
+ The symbol to be added.
+ index:
+ Optional int id to which the symbol should be assigned.
+ If it is not available, a ValueError will be raised.
+
+ Returns:
+ The int id to which the symbol has been assigned.
+ '''
+ # Already in the table? Return its ID.
+ if symbol in self._sym2id:
+ return self._sym2id[symbol]
+ # Specific ID not provided - use next available.
+ if index is None:
+ index = self._next_available_id
+ # Specific ID provided but not available.
+ if index in self._id2sym:
+ raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - "
+ f"already occupied by {self._id2sym[index]}")
+ self._sym2id[symbol] = index
+ self._id2sym[index] = symbol
+
+ # Update next available ID if needed
+ if self._next_available_id <= index:
+ self._next_available_id = index + 1
+
+ return index
+
+ def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
+ '''Get a symbol for an id or get an id for a symbol
+
+ Args:
+ k:
+ If it is an id, it tries to find the symbol corresponding
+ to the id; if it is a symbol, it tries to find the id
+ corresponding to the symbol.
+
+ Returns:
+ An id or a symbol depending on the given `k`.
+ '''
+ if isinstance(k, int):
+ return self._id2sym[k]
+ else:
+ return self._sym2id[k]
+
+ def merge(self, other: 'SymbolTable') -> 'SymbolTable':
+ '''Create a union of two SymbolTables.
+ Raises an AssertionError if the same IDs are occupied by
+ different symbols.
+
+ Args:
+ other:
+ A symbol table to merge with ``self``.
+
+ Returns:
+ A new symbol table.
+ '''
+ self._check_compatible(other)
+
+ id2sym = {**self._id2sym, **other._id2sym}
+ sym2id = {**self._sym2id, **other._sym2id}
+
+ return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps)
+
+ def _check_compatible(self, other: 'SymbolTable') -> None:
+ # Epsilon compatibility
+ assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \
+ f'{self.eps} != {other.eps}'
+ # IDs compatibility
+ common_ids = set(self._id2sym).intersection(other._id2sym)
+ for idx in common_ids:
+ assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \
+ f'self[idx] = "{self[idx]}", ' \
+ f'other[idx] = "{other[idx]}"'
+ # Symbols compatibility
+ common_symbols = set(self._sym2id).intersection(other._sym2id)
+ for sym in common_symbols:
+            assert self[sym] == other[sym], f'ID conflict for symbol: "{sym}", ' \
+ f'self[sym] = "{self[sym]}", ' \
+ f'other[sym] = "{other[sym]}"'
+
+ def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]:
+ return self.get(item)
+
+ def __contains__(self, item: Union[int, Symbol]) -> bool:
+ if isinstance(item, int):
+ return item in self._id2sym
+ else:
+ return item in self._sym2id
+
+ def __len__(self) -> int:
+ return len(self._id2sym)
+
+ def __eq__(self, other: 'SymbolTable') -> bool:
+ if len(self) != len(other):
+ return False
+
+ for s in self.symbols:
+ if self[s] != other[s]:
+ return False
+
+ return True
+
+ @property
+ def ids(self) -> List[int]:
+ '''Returns a list of integer IDs corresponding to the symbols.
+ '''
+ ans = list(self._id2sym.keys())
+ ans.sort()
+ return ans
+
+ @property
+ def symbols(self) -> List[Symbol]:
+ '''Returns a list of symbols (e.g., strings) corresponding to
+ the integer IDs.
+ '''
+ ans = list(self._sym2id.keys())
+ ans.sort()
+ return ans
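+
+# Usage sketch (illustrative only; the symbols and file name are arbitrary):
+#
+#   table = SymbolTable.from_str('a 1\nb 2')
+#   table.add('c')                    # -> 3, the next available id
+#   table['a']                        # -> 1
+#   table[2]                          # -> 'b'
+#   'c' in table                      # -> True
+#   merged = table.merge(SymbolTable.from_str('d 4'))
+#   merged.to_file('symbols.txt')     # one "<symbol> <id>" pair per line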