Spaces:
Running
Running
| from typing import Callable | |
| import numpy as np | |
| from modules.utils.g2p import ( | |
| kana_to_phonemes_openjtalk, | |
| pinyin_to_phonemes_ace, | |
| pinyin_to_phonemes_opencpop, | |
| ) | |
| from .base import AbstractSVSModel | |
| from .registry import register_svs_model | |
| class ESPNetSVS(AbstractSVSModel): | |
| def __init__(self, model_id: str, device="auto", cache_dir="cache", **kwargs): | |
| from espnet2.bin.svs_inference import SingingGenerate | |
| from espnet_model_zoo.downloader import ModelDownloader | |
| if device == "auto": | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| self.device = device | |
| downloaded = ModelDownloader(cache_dir).download_and_unpack(model_id) | |
| self.model = SingingGenerate( | |
| train_config=downloaded["train_config"], | |
| model_file=downloaded["model_file"], | |
| device=self.device, | |
| ) | |
| self.model_id = model_id | |
| self.output_sample_rate = self.model.fs | |
| self.phoneme_mappers = self._build_phoneme_mappers() | |
| def _build_phoneme_mappers(self) -> dict[str, Callable[[str], list[str]]]: | |
| if self.model_id == "espnet/aceopencpop_svs_visinger2_40singer_pretrain": | |
| phoneme_mappers = { | |
| "mandarin": pinyin_to_phonemes_opencpop, | |
| } | |
| elif self.model_id == "espnet/mixdata_svs_visinger2_spkemb_lang_pretrained": | |
| def mandarin_mapper(pinyin: str) -> list[str]: | |
| phns = pinyin_to_phonemes_ace(pinyin) | |
| return [phn + "@zh" for phn in phns] | |
| def japanese_mapper(kana: str) -> list[str]: | |
| phones = kana_to_phonemes_openjtalk(kana) | |
| return [phn + "@jp" for phn in phones] | |
| phoneme_mappers = { | |
| "mandarin": mandarin_mapper, | |
| "japanese": japanese_mapper, | |
| } | |
| else: | |
| phoneme_mappers = {} | |
| return phoneme_mappers | |
| def _preprocess(self, score: list[tuple[float, float, str, int] | tuple[float, float, str, float]], language: str): | |
| if language not in self.phoneme_mappers: | |
| raise ValueError(f"Unsupported language: {language} for {self.model_id}") | |
| phoneme_mapper = self.phoneme_mappers[language] | |
| # text to phoneme | |
| notes = [] | |
| phns = [] | |
| pre_phn = None | |
| for st, ed, text, pitch in score: | |
| assert text not in [ | |
| "<AP>", | |
| "<SP>", | |
| ], f"Proccessed score segments should not contain <AP> or <SP>. {score}" # TODO: remove in PR, only for debug | |
| if text == "AP" or text == "SP": | |
| lyric_units = [text] | |
| phn_units = [text] | |
| elif text == "-" or text == "——": | |
| lyric_units = [text] | |
| if pre_phn is None: | |
| raise ValueError( | |
| f"Text `{text}` cannot be recognized by {self.model_id}. Lyrics cannot start with a lyric continuation symbol `-` or `——`" | |
| ) | |
| phn_units = [pre_phn] | |
| else: | |
| try: | |
| lyric_units = phoneme_mapper(text) | |
| except ValueError as e: | |
| raise ValueError( | |
| f"Text `{text}` cannot be recognized by {self.model_id}" | |
| ) from e | |
| phn_units = lyric_units | |
| notes.append((st, ed, "".join(lyric_units), pitch, "_".join(phn_units))) | |
| phns.extend(phn_units) | |
| pre_phn = phn_units[-1] | |
| batch = { | |
| "score": ( | |
| 120, # does not affect svs result, as note durations are in time unit | |
| notes, | |
| ), | |
| "text": " ".join(phns), | |
| } | |
| return batch | |
| def synthesize( | |
| self, score: list[tuple[float, float, str, float] | tuple[float, float, str, int]], language: str, speaker: str, **kwargs | |
| ): | |
| batch = self._preprocess(score, language) | |
| if self.model_id == "espnet/aceopencpop_svs_visinger2_40singer_pretrain": | |
| sid = np.array([int(speaker)]) | |
| output_dict = self.model(batch, sids=sid) | |
| elif self.model_id == "espnet/mixdata_svs_visinger2_spkemb_lang_pretrained": | |
| langs = { | |
| "mandarin": 2, | |
| "japanese": 1, | |
| } | |
| if language not in langs: | |
| raise ValueError( | |
| f"Unsupported language: {language} for {self.model_id}" | |
| ) | |
| lid = np.array([langs[language]]) | |
| spk_embed = np.load(speaker) | |
| output_dict = self.model(batch, lids=lid, spembs=spk_embed) | |
| else: | |
| raise NotImplementedError(f"Model {self.model_id} not supported") | |
| wav_info = output_dict["wav"].cpu().numpy() | |
| return wav_info, self.output_sample_rate | |