# SingingSDS / svs_utils.py
import json
import random
import unicodedata

import numpy as np
from espnet2.bin.svs_inference import SingingGenerate
from espnet_model_zoo.downloader import ModelDownloader
from kanjiconv import KanjiConv

from util import get_pinyin, get_tokenizer, postprocess_phn, preprocess_input

kanji_to_kana = KanjiConv()
def svs_warmup(config):
    """
    Load modules and the SVS model, then run a dummy forward pass to warm up.

    Input: config dict/namespace (e.g., model path, cache dir, device, language,
        possibly speaker selection)
    Return: the loaded SingingGenerate model, ready for inference
    """
if config.model_path.startswith("espnet"):
espnet_downloader = ModelDownloader(config.cache_dir)
downloaded = espnet_downloader.download_and_unpack(config.model_path)
model = SingingGenerate(
train_config=downloaded["train_config"],
model_file=downloaded["model_file"],
device=config.device,
)
        # Minimal dummy score for the warmup pass; each note is
        # (start_sec, end_sec, lyric, midi_pitch, phonemes).
        dummy_batch = {
            "score": (
                75,  # tempo (bpm)
                [
                    (0.0, 0.25, "r_en", 63.0, "r_en"),
                    (0.25, 0.5, "—", 63.0, "en"),
                ],
            ),
            "text": "r en en",
        }
        model(
            dummy_batch,
            lids=np.array([2]),
            spembs=np.load("resource/singer/singer_embedding_ace-2.npy"),
        )  # warmup forward pass
else:
raise NotImplementedError(f"Model {config.model_path} not supported")
return model
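# Example (sketch): warm up once at startup. Field values mirror the demo at
# the bottom of this file; the paths are illustrative and must exist locally.
#   import argparse
#   config = argparse.Namespace(
#       model_path="espnet/mixdata_svs_visinger2_spkemb_lang_pretrained",
#       cache_dir="cache",
#       device="cpu",
#   )
#   model = svs_warmup(config)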
# Small (yoon) kana mapped to their full-size counterparts.
yoon_map = {
    "ぁ": "あ", "ぃ": "い", "ぅ": "う", "ぇ": "え", "ぉ": "お",
    "ゃ": "や", "ゅ": "ゆ", "ょ": "よ", "ゎ": "わ"
}
def replace_chouonpu(hiragana_text):
    """Replace the long-vowel mark 「ー」 with the vowel it extends,
    since the upstream kana-conversion packages do not handle it."""
vowels = {
"あ": "あ", "い": "い", "う": "う", "え": "え", "お": "う",
"か": "あ", "き": "い", "く": "う", "け": "え", "こ": "う",
"さ": "あ", "し": "い", "す": "う", "せ": "え", "そ": "う",
"た": "あ", "ち": "い", "つ": "う", "て": "え", "と": "う",
"な": "あ", "に": "い", "ぬ": "う", "ね": "え", "の": "う",
"は": "あ", "ひ": "い", "ふ": "う", "へ": "え", "ほ": "う",
"ま": "あ", "み": "い", "む": "う", "め": "え", "も": "う",
"や": "あ", "ゆ": "う", "よ": "う",
"ら": "あ", "り": "い", "る": "う", "れ": "え", "ろ": "う",
"わ": "あ", "を": "う",
}
new_text = []
for i, char in enumerate(hiragana_text):
if char == "ー" and i > 0:
prev_char = new_text[-1]
if prev_char in yoon_map:
prev_char = yoon_map[prev_char]
new_text.append(vowels.get(prev_char, prev_char))
else:
new_text.append(char)
return "".join(new_text)
def is_small_kana(kana): # ょ True よ False
for char in kana:
name = unicodedata.name(char, "")
if "SMALL" in name:
return True
return False
def kanji_to_SVSDictKana(text):
    """Convert raw Japanese text to a list of kana units, where a base kana
    plus a following small (yoon) kana count as one unit (e.g. きょ)."""
    hiragana_text = kanji_to_kana.to_hiragana(text.replace(" ", ""))
    hiragana_text_wl = replace_chouonpu(hiragana_text).split(" ")
    final_ls = []
    for subword in hiragana_text_wl:
        sl_prev = 0
        while sl_prev < len(subword):
            # Group a kana with a following small kana into one unit.
            if sl_prev + 1 < len(subword) and subword[sl_prev + 1] in yoon_map:
                final_ls.append(subword[sl_prev : sl_prev + 2])
                sl_prev += 2
            else:
                final_ls.append(subword[sl_prev])
                sl_prev += 1
    return final_ls
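# Illustrative example (assuming kanjiconv reads 今日 as きょう):
#   kanji_to_SVSDictKana("今日") -> ["きょ", "う"]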
def svs_text_preprocessor(model_path, texts, lang):
    """
    Input:
        - model_path (str), used to pick the matching tokenizer
        - texts (str), lyrics in Chinese or Japanese characters
        - lang (str), language label ("zh" or "jp")
    Output:
        - lyric_ls (lyric list), each element like 'k@zhe@zh'
        - sybs (phoneme list with '_'), each element like 'k@zh_e@zh'
        - labels (phoneme list without '_'), each element like 'k@zh'
    """
    if texts is None:
        raise ValueError("texts is None when calling svs_text_preprocessor")
    # preprocess
    if lang == "zh":
        texts = preprocess_input(texts, "")
        text_list = get_pinyin(texts)
    elif lang == "jp":
        text_list = kanji_to_SVSDictKana(texts)
    else:
        raise NotImplementedError(f"Language {lang} not supported")
# text to phoneme
tokenizer = get_tokenizer(model_path, lang)
sybs = [] # phoneme list
for text in text_list:
if text == "AP" or text == "SP":
rev = [text]
elif text == "-" or text == "——":
rev = [text]
else:
rev = tokenizer(text)
        if rev is False:
            raise ValueError(f"text `{text}` is invalid!")
rev = postprocess_phn(rev, model_path, lang)
phns = "_".join(rev)
sybs.append(phns)
lyric_ls = []
labels = []
pre_phn = ""
for phns in sybs:
if phns == "-" or phns == "——":
phns = pre_phn
phn_list = phns.split("_")
lyric = "".join(phn_list)
for phn in phn_list:
labels.append(phn)
pre_phn = labels[-1]
lyric_ls.append(lyric)
return lyric_ls, sybs, labels
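# Illustrative example (the exact phonemes depend on the model's tokenizer;
# shapes follow the docstring above):
#   lyric_ls, sybs, labels = svs_text_preprocessor(config.model_path, "天气", "zh")
#   # e.g. lyric_ls -> ['t@zhian@zh', 'q@zhi@zh'], sybs -> ['t@zh_ian@zh', 'q@zh_i@zh'],
#   #      labels -> ['t@zh', 'ian@zh', 'q@zh', 'i@zh']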
def create_batch_with_randomized_melody(lyric_ls, sybs, labels, config):
    """
    Input:
        - lyric_ls (list of str), one lyric unit per note, e.g. 'k@zhe@zh'
        - sybs (list of str), phonemes per note joined by '_', e.g. 'k@zh_e@zh'
        - labels (list of str), flat phoneme sequence, e.g. ['k@zh', 'e@zh']
        - config, inference configuration (currently unused here)
    Output:
        - batch (dict), e.g.
          {'score': (75, [[0, 0.48057527844210024, 'n@zhi@zh', 66, 'n@zh_i@zh'],
                          [0.48057527844210024, 0.8049310140914353, 'k@zhe@zh', 57, 'k@zh_e@zh'],
                          [0.8049310140914353, 1.1905956333296641, 'm@zhei@zh', 64, 'm@zh_ei@zh']]),
           'text': 'n@zh i@zh k@zh e@zh m@zh ei@zh'}
    """
    tempo = 120
    len_note = len(lyric_ls)
    notes = []
    st = 0
    for id_lyric in range(len_note):
        # Random pitch in the MIDI range 57-69 and duration in 0.1-0.5 s.
        pitch = random.randint(57, 69)
        duration = round(random.uniform(0.1, 0.5), 4)
        ed = st + duration
        note = [st, ed, lyric_ls[id_lyric], pitch, sybs[id_lyric]]
        st = ed
        notes.append(note)
phns_str = " ".join(labels)
batch = {
"score": (
int(tempo),
notes,
),
"text": phns_str,
}
return batch
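# Example (sketch): build a batch with a random melody for one Mandarin
# syllable (token strings follow the docstring above):
#   batch = create_batch_with_randomized_melody(
#       ["k@zhe@zh"], ["k@zh_e@zh"], ["k@zh", "e@zh"], config
#   )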
def svs_inference(answer_text, svs_model, config, **kwargs):
lyric_ls, sybs, labels = svs_text_preprocessor(
config.model_path, answer_text, config.lang
)
if config.melody_source.startswith("random_generate"):
batch = create_batch_with_randomized_melody(lyric_ls, sybs, labels, config)
elif config.melody_source.startswith("random_select"):
segment_iterator = song_segment_iterator(kwargs["song_db"], kwargs["metadata"])
batch = align_score_and_text(segment_iterator, lyric_ls, sybs, labels, config)
else:
raise NotImplementedError(f"melody source {config.melody_source} not supported")
if config.model_path == "espnet/aceopencpop_svs_visinger2_40singer_pretrain":
sid = np.array([int(config.speaker)])
output_dict = svs_model(batch, sids=sid)
elif config.model_path == "espnet/mixdata_svs_visinger2_spkemb_lang_pretrained":
        # Language IDs used by this pretrained checkpoint.
        langs = {
            "zh": 2,
            "jp": 1,
            "en": 2,
        }
lid = np.array([langs[config.lang]])
spk_embed = np.load(config.speaker)
output_dict = svs_model(batch, lids=lid, spembs=spk_embed)
else:
raise NotImplementedError(f"Model {config.model_path} not supported")
wav_info = output_dict["wav"].cpu().numpy()
return wav_info
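# Example (sketch): with melody_source="random_generate" no extra kwargs are
# needed; "random_select.*" sources also require song_db and metadata
# (see the __main__ demo below):
#   wav = svs_inference("天气真好", model, config)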
def estimate_sentence_length(query, config, song2note_lengths):
    if config.melody_source == "random_select.touhou":
        # Touhou melodies are sampled on the fly, so there is no length constraint.
        phrase_length = None
        metadata = {"song_name": "touhou"}
        return phrase_length, metadata
    elif config.melody_source.startswith("random_select"):
        song_name = random.choice(list(song2note_lengths.keys()))
        phrase_length = song2note_lengths[song_name]
        metadata = {"song_name": song_name}
        return phrase_length, metadata
    else:
        raise NotImplementedError(f"melody source {config.melody_source} not supported")
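# Example (sketch, using the song2note_lengths mapping returned by
# load_song_database below):
#   phrase_length, metadata = estimate_sentence_length(None, config, song2note_lengths)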
def align_score_and_text(segment_iterator, lyric_ls, sybs, labels, config):
text = []
lyric_idx = 0
notes_info = []
while lyric_idx < len(lyric_ls):
score = next(segment_iterator)
for note_start_time, note_end_time, reference_note_lyric, note_midi in zip(
score["note_start_times"],
score["note_end_times"],
score["note_lyrics"],
score["note_midi"],
):
if reference_note_lyric in ["<AP>", "<SP>"]:
notes_info.append(
[
note_start_time,
note_end_time,
reference_note_lyric.strip("<>"),
note_midi,
reference_note_lyric.strip("<>"),
]
)
text.append(reference_note_lyric.strip("<>"))
elif (
reference_note_lyric in ["-", "——"]
and config.melody_source == "random_select.take_lyric_continuation"
):
notes_info.append(
[
note_start_time,
note_end_time,
reference_note_lyric,
note_midi,
text[-1],
]
)
text.append(text[-1])
else:
notes_info.append(
[
note_start_time,
note_end_time,
lyric_ls[lyric_idx],
note_midi,
sybs[lyric_idx],
]
)
text += sybs[lyric_idx].split("_")
lyric_idx += 1
if lyric_idx >= len(lyric_ls):
break
batch = {
"score": (
score["tempo"], # Assume the tempo is the same for all segments
notes_info,
),
"text": " ".join(text),
}
return batch
def load_list_from_json(json_path):
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("The contents of the JSON file are not a list.")
    return [
        {
            "tempo": d["tempo"],
            "note_start_times": [n[0] * (100 / d["tempo"]) for n in d["score"]],
            "note_end_times": [n[1] * (100 / d["tempo"]) for n in d["score"]],
            "note_lyrics": ["" for _ in d["score"]],
            "note_midi": [n[2] for n in d["score"]],
        }
        for d in data
    ]
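# Assumed JSON schema (inferred from the indexing above): a list of entries
#   {"tempo": <bpm>, "score": [[start_beat, end_beat, midi_pitch], ...]}
# where start/end beats are rescaled by 100 / tempo into note times.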
def song_segment_iterator(song_db, metadata):
song_name = metadata["song_name"]
if song_name.startswith("kising_"):
        # Yield segments named f"{song_name}_001", f"{song_name}_002", ... in order.
segment_id = 1
while f"{song_name}_{segment_id:03d}" in song_db.index:
yield song_db.loc[f"{song_name}_{segment_id:03d}"]
segment_id += 1
elif song_name.startswith("touhou"):
        # Yield randomly chosen segments from the Touhou note data.
data = load_list_from_json("data/touhou/note_data.json")
while True:
yield random.choice(data)
else:
raise NotImplementedError(f"song name {song_name} not supported")
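# Example (sketch; "kising_4001" is a hypothetical ID, real ones come from the
# segment_id index of the song database):
#   it = song_segment_iterator(song_db, {"song_name": "kising_4001"})
#   first_segment = next(it)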
def load_song_database(config):
from datasets import load_dataset
song_db = load_dataset(
"jhansss/kising_score_segments", cache_dir="cache", split="train"
).to_pandas()
song_db.set_index("segment_id", inplace=True)
if ".take_lyric_continuation" in config.melody_source:
with open("data/song2word_lengths.json", "r") as f:
song2note_lengths = json.load(f)
else:
with open("data/song2note_lengths.json", "r") as f:
song2note_lengths = json.load(f)
return song2note_lengths, song_db
if __name__ == "__main__":
import argparse
import soundfile as sf
    # -------- demo: generate audio from a randomly selected song -------- #
config = argparse.Namespace(
model_path="espnet/mixdata_svs_visinger2_spkemb_lang_pretrained",
cache_dir="cache",
device="cuda", # "cpu"
melody_source="random_select.touhou", #"random_generate" "random_select.take_lyric_continuation", "random_select.touhou"
lang="zh",
speaker="resource/singer/singer_embedding_ace-2.npy",
)
# load model
model = svs_warmup(config)
if config.lang == "zh":
answer_text = "天气真好\n空气清新\n气温温和\n风和日丽\n天高气爽\n阳光明媚"
elif config.lang == "jp":
answer_text = "流れてく時の中ででもけだるさが"
else:
print(f"Currently system does not support {config.lang}")
exit(1)
sample_rate = 44100
if config.melody_source.startswith("random_select"):
# load song database: jhansss/kising_score_segments
song2note_lengths, song_db = load_song_database(config)
# get song_name and phrase_length
phrase_length, metadata = estimate_sentence_length(
None, config, song2note_lengths
)
        # In the full system, phrase_length would be added to the LLM prompt,
        # which then returns the answer lyrics.
additional_kwargs = {"song_db": song_db, "metadata": metadata}
else:
additional_kwargs = {}
wav_info = svs_inference(answer_text, model, config, **additional_kwargs)
    # Write the waveform to "{melody_source}_{lang}.wav".
    save_name = config.melody_source
sf.write(f"{save_name}_{config.lang}.wav", wav_info, samplerate=sample_rate)