jhansss commited on
Commit
4d8ad2d
·
1 Parent(s): 6f349df

Refactor svs_inference and related functions; Bug fixes and code cleanup

Browse files
Files changed (3) hide show
  1. server.py +2 -5
  2. svs_utils.py +68 -126
  3. util.py +12 -6
server.py CHANGED
@@ -86,12 +86,9 @@ async def process_audio(file: UploadFile = File(...)):
86
  f.write(output)
87
 
88
  wav_info = svs_inference(
89
- config.model_path,
90
- svs_model,
91
  output,
92
- lang=config.lang,
93
- random_gen=True,
94
- fs=44100
95
  )
96
  sf.write("tmp/response.wav", wav_info, samplerate=44100)
97
 
 
86
  f.write(output)
87
 
88
  wav_info = svs_inference(
 
 
89
  output,
90
+ svs_model,
91
+ config,
 
92
  )
93
  sf.write("tmp/response.wav", wav_info, samplerate=44100)
94
 
svs_utils.py CHANGED
@@ -1,54 +1,13 @@
1
- from util import (
2
- preprocess_input,
3
- postprocess_phn,
4
- get_tokenizer,
5
- get_pinyin,
6
- )
7
- from espnet_model_zoo.downloader import ModelDownloader
8
- from espnet2.bin.svs_inference import SingingGenerate
9
  import librosa
10
- import torch
11
  import numpy as np
12
- import random
13
- import json
 
14
 
15
- import argparse
16
- import soundfile as sf
17
-
18
- # the code below should be in app.py than svs_utils.py
19
- # espnet_model_dict = {
20
- # "Model①(Chinese)-zh": "espnet/aceopencpop_svs_visinger2_40singer_pretrain",
21
- # "Model②(Multilingual)-zh": "espnet/mixdata_svs_visinger2_spkembed_lang_pretrained",
22
- # "Model②(Multilingual)-jp": "espnet/mixdata_svs_visinger2_spkembed_lang_pretrained",
23
- # }
24
-
25
-
26
- singer_embeddings = {
27
- "espnet/aceopencpop_svs_visinger2_40singer_pretrain": {
28
- "singer1 (male)": 1,
29
- "singer2 (female)": 12,
30
- "singer3 (male)": 23,
31
- "singer4 (female)": 29,
32
- "singer5 (male)": 18,
33
- "singer6 (female)": 8,
34
- "singer7 (male)": 25,
35
- "singer8 (female)": 5,
36
- "singer9 (male)": 10,
37
- "singer10 (female)": 15,
38
- },
39
- "espnet/mixdata_svs_visinger2_spkembed_lang_pretrained": {
40
- "singer1 (male)": "resource/singer/singer_embedding_ace-1.npy",
41
- "singer2 (female)": "resource/singer/singer_embedding_ace-2.npy",
42
- "singer3 (male)": "resource/singer/singer_embedding_ace-3.npy",
43
- "singer4 (female)": "resource/singer/singer_embedding_ace-8.npy",
44
- "singer5 (male)": "resource/singer/singer_embedding_ace-7.npy",
45
- "singer6 (female)": "resource/singer/singer_embedding_itako.npy",
46
- "singer7 (male)": "resource/singer/singer_embedding_ofuton.npy",
47
- "singer8 (female)": "resource/singer/singer_embedding_kising_orange.npy",
48
- "singer9 (male)": "resource/singer/singer_embedding_m4singer_Tenor-1.npy",
49
- "singer10 (female)": "resource/singer/singer_embedding_m4singer_Alto-4.npy",
50
- },
51
- }
52
 
53
 
54
  def svs_warmup(config):
@@ -86,7 +45,7 @@ def svs_text_preprocessor(model_path, texts, lang):
86
  fs = 44100
87
 
88
  if texts is None:
89
- return (fs, np.array([0.0])), "Error: No Text provided!"
90
 
91
  # preprocess
92
  if lang == "zh":
@@ -129,7 +88,7 @@ def svs_text_preprocessor(model_path, texts, lang):
129
  return lyric_ls, sybs, labels
130
 
131
 
132
- def svs_get_batch(model_path, answer_text, lang, random_gen=True):
133
  """
134
  Input:
135
  - answer_text (str), in Chinese character or Japanese character
@@ -144,72 +103,55 @@ def svs_get_batch(model_path, answer_text, lang, random_gen=True):
144
  'text': 'n@zh i@zh k@zh e@zh m@zh ei@zh'}
145
  """
146
  tempo = 120
147
- lyric_ls, sybs, labels = svs_text_preprocessor(model_path, answer_text, lang)
148
  len_note = len(lyric_ls)
149
  notes = []
150
- if random_gen:
151
- # midi_range = (57,69)
152
- st = 0
153
- for id_lyric in range(len_note):
154
- pitch = random.randint(57, 69)
155
- period = round(random.uniform(0.1, 0.5), 4)
156
- ed = st + period
157
- note = [st, ed, lyric_ls[id_lyric], pitch, sybs[id_lyric]]
158
- st = ed
159
- notes.append(note)
160
-
161
- phns_str = " ".join(labels)
162
- batch = {
163
- "score": (
164
- int(tempo),
165
- notes,
166
- ),
167
- "text": phns_str,
168
- }
169
-
170
- # print(batch)
171
  return batch
172
 
173
 
174
- langs = {
175
- "zh": 2,
176
- "jp": 1,
177
- "en": 2,
178
- }
179
-
180
- exist_model = "Null"
181
- svs = None
182
-
183
-
184
- def svs_inference(model_name, model_svs, answer_text, lang, random_gen=True, fs=44100):
185
- batch = svs_get_batch(model_name, answer_text, lang, random_gen=random_gen)
186
-
187
- # Infer
188
- spk = "singer1 (male)"
189
- global exist_model
190
- global svs
191
- svs = model_svs
192
- exist_model = model_name
193
- # if exist_model == "Null" or exist_model != model_name:
194
- # # device = "cpu"
195
- # device = "cuda" if torch.cuda.is_available() else "cpu"
196
- # d = ModelDownloader(cachedir="./cache")
197
- # pretrain_downloaded = d.download_and_unpack(model_name)
198
- # svs = SingingGenerate(
199
- # train_config = pretrain_downloaded["train_config"],
200
- # model_file = pretrain_downloaded["model_file"],
201
- # device = device
202
- # )
203
- # exist_model = model_name
204
- if model_name == "Model①(Chinese)-zh":
205
- sid = np.array([singer_embeddings[model_name][spk]])
206
- output_dict = svs(batch, sids=sid)
207
  else:
208
- lid = np.array([langs[lang]])
209
- spk_embed = np.load("resource/singer/singer_embedding_ace-2.npy")
210
- output_dict = svs(batch, lids=lid, spembs=spk_embed)
211
- wav_info = output_dict["wav"].cpu().numpy()
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  return wav_info
214
 
215
 
@@ -230,8 +172,6 @@ def singmos_evaluation(predictor, wav_info, fs):
230
 
231
  def estimate_sentence_length(query, config, song2note_lengths):
232
  if config.melody_source.startswith("random_select"):
233
- # random select a song from database, and return its value in the phrase_length column
234
- # return phrase_length column and song name
235
  song_name = random.choice(list(song2note_lengths.keys()))
236
  phrase_length = song2note_lengths[song_name]
237
  metadata = {"song_name": song_name}
@@ -263,7 +203,10 @@ def align_score_and_text(segment_iterator, lyric_ls, sybs, labels, config):
263
  ]
264
  )
265
  text.append(reference_note_lyric.strip("<>"))
266
- elif reference_note_lyric in ["-", "——"] and config.melody_source == "random_select.take_lyric_continuation":
 
 
 
267
  notes_info.append(
268
  [
269
  note_start_time,
@@ -311,6 +254,8 @@ def song_segment_iterator(song_db, metadata):
311
 
312
 
313
  def load_song_database(config):
 
 
314
  song_db = load_dataset(
315
  "jhansss/kising_score_segments", cache_dir="cache", split="train"
316
  ).to_pandas()
@@ -325,6 +270,8 @@ def load_song_database(config):
325
 
326
 
327
  if __name__ == "__main__":
 
 
328
 
329
  # -------- demo code for generate audio from randomly selected song ---------#
330
  config = argparse.Namespace(
@@ -333,6 +280,7 @@ if __name__ == "__main__":
333
  device="cuda", # "cpu"
334
  melody_source="random_generate", # "random_select.take_lyric_continuation"
335
  lang="zh",
 
336
  )
337
 
338
  # load model
@@ -344,28 +292,22 @@ if __name__ == "__main__":
344
 
345
  if config.melody_source.startswith("random_select"):
346
  # load song database: jhansss/kising_score_segments
347
- from datasets import load_dataset
348
  song2note_lengths, song_db = load_song_database(config)
349
 
350
  # get song_name and phrase_length
 
 
 
351
  phrase_length, metadata = estimate_sentence_length(None, config, song2note_lengths)
352
 
353
  # then, phrase_length info should be added to llm prompt, and get the answer lyrics from llm
354
  # e.g. answer_text = "天气真好\n空气清新"
355
- lyric_ls, sybs, labels = svs_text_preprocessor(
356
- config.model_path, answer_text, config.lang
357
- )
358
- segment_iterator = song_segment_iterator(song_db, metadata)
359
- batch = align_score_and_text(segment_iterator, lyric_ls, sybs, labels, config)
360
- singer_embedding = np.load(singer_embeddings[config.model_path]["singer2 (female)"])
361
- lid = np.array([langs[config.lang]])
362
- output_dict = model(batch, lids=lid, spembs=singer_embedding)
363
- wav_info = output_dict["wav"].cpu().numpy()
364
 
365
-
366
- elif config.melody_source.startswith("random_generate"):
367
- wav_info = svs_inference(config.model_path, model, answer_text, lang=config.lang, random_gen=True, fs=sample_rate)
368
 
369
  # write wav to output_retrieved.wav
370
- save_name = config.melody_source.split('.')[0]
371
  sf.write(f"{save_name}.wav", wav_info, samplerate=sample_rate)
 
1
+ import json
2
+ import random
3
+
 
 
 
 
 
4
  import librosa
 
5
  import numpy as np
6
+ import torch
7
+ from espnet2.bin.svs_inference import SingingGenerate
8
+ from espnet_model_zoo.downloader import ModelDownloader
9
 
10
+ from util import get_pinyin, get_tokenizer, postprocess_phn, preprocess_input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  def svs_warmup(config):
 
45
  fs = 44100
46
 
47
  if texts is None:
48
+ raise ValueError("texts is None when calling svs_text_preprocessor")
49
 
50
  # preprocess
51
  if lang == "zh":
 
88
  return lyric_ls, sybs, labels
89
 
90
 
91
+ def create_batch_with_randomized_melody(lyric_ls, sybs, labels, config):
92
  """
93
  Input:
94
  - answer_text (str), in Chinese character or Japanese character
 
103
  'text': 'n@zh i@zh k@zh e@zh m@zh ei@zh'}
104
  """
105
  tempo = 120
 
106
  len_note = len(lyric_ls)
107
  notes = []
108
+ # midi_range = (57,69)
109
+ st = 0
110
+ for id_lyric in range(len_note):
111
+ pitch = random.randint(57, 69)
112
+ period = round(random.uniform(0.1, 0.5), 4)
113
+ ed = st + period
114
+ note = [st, ed, lyric_ls[id_lyric], pitch, sybs[id_lyric]]
115
+ st = ed
116
+ notes.append(note)
117
+ phns_str = " ".join(labels)
118
+ batch = {
119
+ "score": (
120
+ int(tempo),
121
+ notes,
122
+ ),
123
+ "text": phns_str,
124
+ }
 
 
 
 
125
  return batch
126
 
127
 
128
+ def svs_inference(answer_text, svs_model, config, **kwargs):
129
+ lyric_ls, sybs, labels = svs_text_preprocessor(
130
+ config.model_path, answer_text, config.lang
131
+ )
132
+ if config.melody_source.startswith("random_generate"):
133
+ batch = create_batch_with_randomized_melody(lyric_ls, sybs, labels, config)
134
+ elif config.melody_source.startswith("random_select"):
135
+ segment_iterator = song_segment_iterator(kwargs["song_db"], kwargs["metadata"])
136
+ batch = align_score_and_text(segment_iterator, lyric_ls, sybs, labels, config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  else:
138
+ raise NotImplementedError(f"melody source {config.melody_source} not supported")
 
 
 
139
 
140
+ if config.model_path == "espnet/aceopencpop_svs_visinger2_40singer_pretrain":
141
+ sid = np.array([config.speaker])
142
+ output_dict = svs_model(batch, sids=sid)
143
+ elif config.model_path == "espnet/mixdata_svs_visinger2_spkembed_lang_pretrained":
144
+ langs = {
145
+ "zh": 2,
146
+ "jp": 1,
147
+ "en": 2,
148
+ }
149
+ lid = np.array([langs[config.lang]])
150
+ spk_embed = np.load(config.speaker)
151
+ output_dict = svs_model(batch, lids=lid, spembs=spk_embed)
152
+ else:
153
+ raise NotImplementedError(f"Model {config.model_path} not supported")
154
+ wav_info = output_dict["wav"].cpu().numpy()
155
  return wav_info
156
 
157
 
 
172
 
173
  def estimate_sentence_length(query, config, song2note_lengths):
174
  if config.melody_source.startswith("random_select"):
 
 
175
  song_name = random.choice(list(song2note_lengths.keys()))
176
  phrase_length = song2note_lengths[song_name]
177
  metadata = {"song_name": song_name}
 
203
  ]
204
  )
205
  text.append(reference_note_lyric.strip("<>"))
206
+ elif (
207
+ reference_note_lyric in ["-", "——"]
208
+ and config.melody_source == "random_select.take_lyric_continuation"
209
+ ):
210
  notes_info.append(
211
  [
212
  note_start_time,
 
254
 
255
 
256
  def load_song_database(config):
257
+ from datasets import load_dataset
258
+
259
  song_db = load_dataset(
260
  "jhansss/kising_score_segments", cache_dir="cache", split="train"
261
  ).to_pandas()
 
270
 
271
 
272
  if __name__ == "__main__":
273
+ import argparse
274
+ import soundfile as sf
275
 
276
  # -------- demo code for generate audio from randomly selected song ---------#
277
  config = argparse.Namespace(
 
280
  device="cuda", # "cpu"
281
  melody_source="random_generate", # "random_select.take_lyric_continuation"
282
  lang="zh",
283
+ speaker="resource/singer/singer_embedding_ace-2.npy",
284
  )
285
 
286
  # load model
 
292
 
293
  if config.melody_source.startswith("random_select"):
294
  # load song database: jhansss/kising_score_segments
 
295
  song2note_lengths, song_db = load_song_database(config)
296
 
297
  # get song_name and phrase_length
298
+ phrase_length, metadata = estimate_sentence_length(
299
+ None, config, song2note_lengths
300
+ )
301
  phrase_length, metadata = estimate_sentence_length(None, config, song2note_lengths)
302
 
303
  # then, phrase_length info should be added to llm prompt, and get the answer lyrics from llm
304
  # e.g. answer_text = "天气真好\n空气清新"
305
+ additional_kwargs = {"song_db": song_db, "metadata": metadata}
306
+ else:
307
+ additional_kwargs = {}
 
 
 
 
 
 
308
 
309
+ wav_info = svs_inference(answer_text, model, config, **additional_kwargs)
 
 
310
 
311
  # write wav to output_retrieved.wav
312
+ save_name = config.melody_source
313
  sf.write(f"{save_name}.wav", wav_info, samplerate=sample_rate)
util.py CHANGED
@@ -21,6 +21,7 @@ def postprocess_phn(phns, model_name, lang):
21
 
22
 
23
  def pyopenjtalk_g2p(text) -> List[str]:
 
24
  with warnings.catch_warnings(record=True) as w:
25
  warnings.simplefilter("always")
26
  # phones is a str object separated by space
@@ -53,20 +54,25 @@ def split_pinyin_py(pinyin: str) -> tuple[str]:
53
 
54
 
55
  def get_tokenizer(model, lang):
56
- if lang == "zh":
57
- if "Chinese" in model:
58
- print("hello")
59
  return lambda text: split_pinyin_py(text)
60
  else:
 
 
 
61
  with open(os.path.join("resource/all_plans.json"), "r") as f:
62
  all_plan_dict = json.load(f)
63
  for plan in all_plan_dict["plans"]:
64
  if plan["language"] == "zh":
65
  zh_plan = plan
66
  return lambda text: split_pinyin_ace(text, zh_plan)
67
- elif lang == "jp":
68
- import pyopenjtalk
69
- return pyopenjtalk_g2p
 
 
 
70
 
71
 
72
  def get_pinyin(texts):
 
21
 
22
 
23
  def pyopenjtalk_g2p(text) -> List[str]:
24
+ import pyopenjtalk
25
  with warnings.catch_warnings(record=True) as w:
26
  warnings.simplefilter("always")
27
  # phones is a str object separated by space
 
54
 
55
 
56
  def get_tokenizer(model, lang):
57
+ if model == "espnet/aceopencpop_svs_visinger2_40singer_pretrain":
58
+ if lang == "zh":
 
59
  return lambda text: split_pinyin_py(text)
60
  else:
61
+ raise ValueError(f"Only support Chinese language for {model}")
62
+ elif model == "espnet/mixdata_svs_visinger2_spkembed_lang_pretrained":
63
+ if lang == "zh":
64
  with open(os.path.join("resource/all_plans.json"), "r") as f:
65
  all_plan_dict = json.load(f)
66
  for plan in all_plan_dict["plans"]:
67
  if plan["language"] == "zh":
68
  zh_plan = plan
69
  return lambda text: split_pinyin_ace(text, zh_plan)
70
+ elif lang == "jp":
71
+ return pyopenjtalk_g2p
72
+ else:
73
+ raise ValueError(f"Only support Chinese and Japanese language for {model}")
74
+ else:
75
+ raise ValueError(f"Only support espnet/aceopencpop_svs_visinger2_40singer_pretrain and espnet/mixdata_svs_visinger2_spkembed_lang_pretrained for now")
76
 
77
 
78
  def get_pinyin(texts):