|
import tempfile |
|
from pathlib import Path |
|
|
|
import fugashi |
|
import numpy as np |
|
import pyopenjtalk |
|
import torch |
|
import torch.nn.functional as F |
|
from transformers import HubertForCTC, Wav2Vec2Processor |
|
|
|
|
|
class CandidateGenerator:
    """Generate furigana (reading) candidates from audio and Japanese text.

    MeCab (via fugashi, using pyopenjtalk's bundled dictionary) proposes
    N-best readings; pyopenjtalk turns each reading into a phoneme sequence;
    a phoneme-level HuBERT CTC model scores each phoneme sequence against
    the audio.  Each candidate thus carries both a MeCab path cost and a
    CTC loss.
    """

    def __init__(self, device: str = "cpu"):
        """Set up the MeCab taggers and load the HuBERT phoneme model.

        Args:
            device: torch device string; the HuBERT model and all CTC
                tensors are placed on this device.
        """
        self.device = device

        # Use the MeCab dictionary that ships inside the pyopenjtalk package
        # so features match what run_njd_from_mecab expects downstream.
        dictionary_dir = Path(pyopenjtalk.__file__).parent / "dictionary"
        # NOTE(review): assert is stripped under `python -O`; an explicit
        # raise would be more robust for this environment check.
        assert dictionary_dir.exists()

        # Overwrite the dicrc inside the *installed* pyopenjtalk dictionary
        # so MeCab can emit the output formats used below (yomi / simple /
        # chasen / chasen2).  This is a package-level side effect that
        # persists across runs.
        with open(dictionary_dir / "dicrc", "w") as f:
            f.write(r"""cost-factor = 800
bos-feature = BOS/EOS,*,*,*,*,*,*,*,*
eval-size = 8
unk-eval-size = 4
node-format-yomi = %pS%f[7]
unk-format-yomi = %M
eos-format-yomi = \n
node-format-simple = %m\t%F-[0,1,2,3]\n
eos-format-simple = EOS\n
node-format-chasen = %m\t%f[7]\t%f[6]\t%F-[0,1,2,3]\t%f[4]\t%f[5]\n
unk-format-chasen = %m\t%m\t%m\t%F-[0,1,2,3]\t\t\n
eos-format-chasen = EOS\n
node-format-chasen2 = %M\t%f[7]\t%f[6]\t%F-[0,1,2,3]\t%f[4]\t%f[5]\n
unk-format-chasen2 = %M\t%m\t%m\t%F-[0,1,2,3]\t\t\n
eos-format-chasen2 = EOS\n
""")

        # An empty file passed as -r keeps MeCab from loading any mecabrc.
        # NOTE(review): reopening a NamedTemporaryFile by name while it is
        # still open works on POSIX but not on Windows — confirm target OS.
        with tempfile.NamedTemporaryFile(mode="w+", delete=True) as tmp:
            # N-best tagger used for candidate generation.
            self.tagger = fugashi.GenericTagger(f"-r {tmp.name} -d {dictionary_dir}")
            # Partial-parse tagger (-p): re-parses an already-segmented
            # candidate, printing "0" per node (-F 0) and the cumulative
            # path cost at EOS (-E %pc); consumed by add_mecab_costs.
            self.tagger_p = fugashi.GenericTagger(
                f"-r {tmp.name} -d {dictionary_dir} -p -F 0 -E %pc"
            )

        # Map half-width ASCII printables (! .. ~) to their full-width
        # forms; the dictionary is keyed on full-width characters.
        self.hankaku_to_zenkaku_table = str.maketrans(
            {chr(i): chr(i + 0xFEE0) for i in range(33, 127)}
        )

        # Japanese phoneme-CTC HuBERT model plus its processor (feature
        # extractor + phoneme tokenizer).
        HUBERT_MODEL_NAME = "prj-beatrice/japanese-hubert-base-phoneme-ctc-v2"
        self.phoneme_hubert = HubertForCTC.from_pretrained(HUBERT_MODEL_NAME).to(device)
        self.wav2vec_processor = Wav2Vec2Processor.from_pretrained(HUBERT_MODEL_NAME)

    def generate(self, text: str, audio_16khz: np.ndarray, num: int) -> dict:
        """Extract reading information from text and audio.

        Args:
            text: Japanese text to annotate.
            audio_16khz: mono waveform sampled at 16 kHz.
            num: number of MeCab N-best candidates to consider.

        Returns:
            Dict with "candidates" (each carrying raw MeCab output,
            features, mecab_cost, NJD features, phonemes, ctc_loss) plus
            "hubert_logits" and "hubert_prediction".
        """
        results = self.run_mecab(text, num)
        self.add_mecab_costs(results)
        self.add_phonemes(text, results)
        self.add_ctc_loss(audio_16khz, results)
        return results

    def run_mecab(self, text: str, num: int) -> dict:
        """Run MeCab in N-best mode and split the output into candidates."""
        # The dictionary expects full-width characters.
        text = text.translate(self.hankaku_to_zenkaku_table)
        nbest: str = self.tagger.nbest(text, num=num)
        candidates = []
        # Candidates are separated by "EOS" lines; appending "\n" lets the
        # final candidate split off on the same delimiter.
        for raw_candidate in (nbest + "\n").split("\nEOS\n"):
            if not raw_candidate:
                continue
            # One "surface\tfeatures" line per morpheme; replacing the tab
            # turns each line into a single comma-separated feature row.
            features = [line.replace("\t", ",") for line in raw_candidate.splitlines()]
            # Restore the EOS marker so tagger_p can re-parse this candidate.
            raw_candidate += "\nEOS"
            candidate = {
                "raw": raw_candidate,
                "features": features,
            }
            candidates.append(candidate)
        return {"candidates": candidates}

    def add_mecab_costs(self, results: dict) -> None:
        """Attach the MeCab path cost to every candidate."""
        for candidate in results["candidates"]:
            raw_candidate = candidate["raw"]
            # tagger_p prints "0" per node and the cumulative cost (%pc) at
            # EOS; stripping leading zeros leaves the cost digits.
            # NOTE(review): a total cost of exactly 0 would leave an empty
            # string and make int() raise ValueError — confirm unreachable.
            mecab_cost = int(self.tagger_p.parse(raw_candidate).lstrip("0"))
            candidate["mecab_cost"] = mecab_cost

    def add_phonemes(self, text: str, results: dict) -> None:
        """Attach NJD features and the phoneme sequence to every candidate."""
        for candidate in results["candidates"]:
            # Feed the MeCab feature rows through pyopenjtalk's NJD pipeline.
            njd_result = pyopenjtalk.run_njd_from_mecab(candidate["features"])

            # Accent post-processing passes: filler accents, accent-nucleus
            # retreat, accent chaining, and odori (repeat-mark) handling.
            postprocessed = pyopenjtalk.modify_filler_accent(njd_result)
            postprocessed = pyopenjtalk.retreat_acc_nuc(postprocessed)
            postprocessed = pyopenjtalk.modify_acc_after_chaining(postprocessed)
            postprocessed = pyopenjtalk.process_odori_features(postprocessed)
            labels = pyopenjtalk.make_label(postprocessed)
            # Full-context labels embed the phoneme as "...-<ph>+...";
            # labels[1:-1] drops the leading/trailing silence labels.
            phonemes = list(map(lambda s: s.split("-")[1].split("+")[0], labels[1:-1]))
            candidate["njd_result"] = njd_result
            candidate["postprocessed"] = postprocessed
            candidate["phonemes"] = phonemes

    def add_ctc_loss(self, audio_16khz: np.ndarray, results: dict) -> None:
        """Attach a CTC loss against the audio to every candidate."""
        SAMPLING_RATE = 16000

        # Pad with 1.0 s of silence in front and 0.5 s behind, giving the
        # CTC alignment slack at the utterance boundaries.
        audio_16khz = np.concatenate(
            [np.zeros(SAMPLING_RATE), audio_16khz, np.zeros(SAMPLING_RATE // 2)]
        )
        inputs = self.wav2vec_processor(
            audio_16khz,
            sampling_rate=SAMPLING_RATE,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            outputs = self.phoneme_hubert(**inputs)
        # Greedy (argmax) decode of the model's own phoneme prediction.
        predicted_ids = outputs.logits.argmax(-1)
        predicted_phonemes_str: str = self.wav2vec_processor.decode(
            predicted_ids[0], spaces_between_special_tokens=True
        )
        predicted_phonemes: list[str] = predicted_phonemes_str.split()

        # logits: (batch=1, frames, vocab)
        assert outputs.logits.ndim == 3, outputs.logits.shape

        # F.ctc_loss expects (frames, batch, vocab) log-probabilities.
        log_probs = F.log_softmax(outputs.logits, dim=-1).transpose(0, 1)

        # All target sequences share the same (full-length) input.
        ctc_input_lengths = torch.tensor([log_probs.size(0)], device=self.device)

        # Deduplicate target sequences by their joined string; the model's
        # own greedy prediction is scored alongside the candidates.
        phonemes_candidates = {predicted_phonemes_str: predicted_phonemes}
        for candidate in results["candidates"]:
            phonemes_str = " ".join(candidate["phonemes"])
            phonemes_candidates[phonemes_str] = candidate["phonemes"]

        # Concatenate all targets into one flat id list; the per-sequence
        # lengths let F.ctc_loss slice it apart (standard CTC batching).
        phoneme_ids = []
        ctc_target_lengths = []
        for phonemes_str, phonemes in phonemes_candidates.items():
            candidate_phoneme_ids = (
                self.wav2vec_processor.tokenizer.convert_tokens_to_ids(phonemes)
            )
            phoneme_ids.extend(candidate_phoneme_ids)
            ctc_target_lengths.append(len(candidate_phoneme_ids))
        phoneme_ids = torch.tensor(phoneme_ids, device=self.device)
        ctc_target_lengths = torch.tensor(ctc_target_lengths, device=self.device)
        assert phoneme_ids.ndim == 1, phoneme_ids.shape

        # cuDNN's CTC kernel has restrictive preconditions; force the
        # generic implementation for the expanded pseudo-batch.
        with torch.backends.cudnn.flags(enabled=False):
            loss = F.ctc_loss(
                # Broadcast the single utterance across every target.
                log_probs.expand(-1, len(ctc_target_lengths), -1),
                phoneme_ids,
                ctc_input_lengths.expand(len(ctc_target_lengths)),
                ctc_target_lengths,
                # The tokenizer's pad token doubles as the CTC blank.
                blank=self.phoneme_hubert.config.pad_token_id,
                reduction="none",  # one loss value per target sequence
            )

        # dict preserves insertion order, so losses line up with the keys.
        phonemes_str_to_ctc_loss = {}
        for phonemes_str, loss in zip(phonemes_candidates, loss.cpu().tolist()):
            phonemes_str_to_ctc_loss[phonemes_str] = loss

        for candidate in results["candidates"]:
            phonemes_str = " ".join(candidate["phonemes"])
            candidate["ctc_loss"] = phonemes_str_to_ctc_loss[phonemes_str]
        results["hubert_logits"] = outputs.logits
        results["hubert_prediction"] = {
            "phonemes": predicted_phonemes,
            "ctc_loss": phonemes_str_to_ctc_loss[predicted_phonemes_str],
        }
|
|
|
|
|
if __name__ == "__main__":
    import sys

    import librosa

    # CLI demo: python <thisfile> "<japanese text>" <audio file>
    run_device = "cpu" if not torch.cuda.is_available() else "cuda"
    generator = CandidateGenerator(run_device)
    input_text, audio_path = sys.argv[1], Path(sys.argv[2])
    # Resample to the 16 kHz the HuBERT model expects.
    waveform, _sr = librosa.load(audio_path, sr=16000)
    results = generator.generate(input_text, waveform, num=10)

    # Print a one-line summary per candidate, then the raw results dict.
    for cand in results["candidates"]:
        phoneme_text = " ".join(cand["phonemes"])
        summary = (
            f"Cost: {cand['mecab_cost']}, "
            f"CTC Loss: {cand['ctc_loss']:.3f}, "
            f"Phonemes: {phoneme_text}"
        )
        print(summary)
    print(results)
|
|
|
|
|
|