# Project Beatrice
# Initial commit
# 5124d5d
import tempfile
from pathlib import Path
import fugashi
import numpy as np
import pyopenjtalk
import torch
import torch.nn.functional as F
from transformers import HubertForCTC, Wav2Vec2Processor
class CandidateGenerator:
    """Generate furigana (reading) candidates from Japanese text and its audio.

    Pipeline (see ``generate``):
      1. MeCab (fugashi + pyopenjtalk's dictionary) proposes the N-best
         morphological analyses of the text.
      2. Each analysis is re-scored to obtain its MeCab path cost.
      3. pyopenjtalk converts each analysis into a phoneme sequence.
      4. A phoneme-recognition HuBERT model scores every phoneme sequence
         against the audio with the CTC loss (a negative log-likelihood, so
         lower means a better match).
    """

    # dicrc written into pyopenjtalk's dictionary directory. openjtalk
    # hardcodes these settings, so the directory ships without a dicrc file,
    # but MeCab needs one to load the dictionary.
    _DICRC = r"""cost-factor = 800
bos-feature = BOS/EOS,*,*,*,*,*,*,*,*
eval-size = 8
unk-eval-size = 4
node-format-yomi = %pS%f[7]
unk-format-yomi = %M
eos-format-yomi = \n
node-format-simple = %m\t%F-[0,1,2,3]\n
eos-format-simple = EOS\n
node-format-chasen = %m\t%f[7]\t%f[6]\t%F-[0,1,2,3]\t%f[4]\t%f[5]\n
unk-format-chasen = %m\t%m\t%m\t%F-[0,1,2,3]\t\t\n
eos-format-chasen = EOS\n
node-format-chasen2 = %M\t%f[7]\t%f[6]\t%F-[0,1,2,3]\t%f[4]\t%f[5]\n
unk-format-chasen2 = %M\t%m\t%m\t%F-[0,1,2,3]\t\t\n
eos-format-chasen2 = EOS\n
"""

    # Hugging Face model id of the phoneme CTC HuBERT.
    HUBERT_MODEL_NAME = "prj-beatrice/japanese-hubert-base-phoneme-ctc-v2"

    def __init__(self, device: str = "cpu"):
        """Prepare the MeCab taggers and load the HuBERT phoneme model.

        Args:
            device: torch device for the HuBERT model (e.g. "cpu", "cuda").

        Raises:
            FileNotFoundError: if pyopenjtalk's ``dictionary`` directory is
                missing.
        """
        self.device = device
        dictionary_dir = Path(pyopenjtalk.__file__).parent / "dictionary"
        if not dictionary_dir.is_dir():
            raise FileNotFoundError(
                f"pyopenjtalk dictionary directory not found: {dictionary_dir}"
            )
        # Only (re)write the dicrc when it is missing or stale, so an install
        # that already has it does not need write access to site-packages.
        dicrc_path = dictionary_dir / "dicrc"
        if (
            not dicrc_path.exists()
            or dicrc_path.read_text(encoding="utf-8", errors="replace")
            != self._DICRC
        ):
            dicrc_path.write_text(self._DICRC, encoding="utf-8")

        # Point MeCab's resource file (-r) at an empty file so no system-wide
        # mecabrc is required; the dictionary is passed explicitly with -d.
        # A plain file inside a TemporaryDirectory is used instead of an open
        # NamedTemporaryFile, because the latter cannot be reopened by name
        # (which MeCab does) on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            rc_path = Path(tmp_dir) / "mecabrc"
            rc_path.touch()
            self.tagger = fugashi.GenericTagger(f"-r {rc_path} -d {dictionary_dir}")
            # Partial-parsing (-p) tagger used to re-score a fixed analysis:
            # every node prints "0" (-F 0) and EOS prints the path cost
            # (-E %pc), so its output is "<one zero per node><cost>".
            self.tagger_p = fugashi.GenericTagger(
                f"-r {rc_path} -d {dictionary_dir} -p -F 0 -E %pc"
            )

        # Printable ASCII except space (U+0021-U+007E) -> full-width forms
        # (U+FF01-U+FF5E), applied to the text before tagging.
        self.hankaku_to_zenkaku_table = str.maketrans(
            {chr(i): chr(i + 0xFEE0) for i in range(33, 127)}
        )
        self.phoneme_hubert = HubertForCTC.from_pretrained(
            self.HUBERT_MODEL_NAME
        ).to(device)
        self.wav2vec_processor = Wav2Vec2Processor.from_pretrained(
            self.HUBERT_MODEL_NAME
        )

    def generate(self, text: str, audio_16khz: np.ndarray, num: int) -> dict:
        """Extract reading candidates for ``text`` and score them against audio.

        Args:
            text: Japanese text.
            audio_16khz: waveform sampled at 16 kHz (presumably 1-D mono —
                TODO confirm with callers).
            num: number of MeCab N-best analyses to request.

        Returns:
            dict with "candidates" (list of dicts holding "raw", "features",
            "mecab_cost", "njd_result", "postprocessed", "phonemes",
            "ctc_loss"), plus "hubert_logits" and "hubert_prediction".
        """
        results = self.run_mecab(text, num)
        self.add_mecab_costs(results)
        self.add_phonemes(text, results)
        self.add_ctc_loss(audio_16khz, results)
        return results

    @staticmethod
    def _split_nbest(nbest: str) -> list[list[str]]:
        """Split MeCab N-best output into one list of node lines per path.

        Each path is terminated by a line reading exactly "EOS". A trailing
        newline (or its absence) does not matter, blank lines are ignored,
        and paths with no nodes are dropped.
        """
        paths = []
        current = []
        for line in nbest.splitlines():
            if line == "EOS":
                if current:
                    paths.append(current)
                current = []
            elif line:
                current.append(line)
        # Tolerate a final path whose EOS line is missing.
        if current:
            paths.append(current)
        return paths

    def run_mecab(self, text: str, num: int) -> dict:
        """Run MeCab in N-best mode.

        Returns:
            {"candidates": [{"raw": str, "features": list[str]}, ...]} where
            "raw" is the MeCab output of one path terminated by "EOS" (the
            input format of the partial-parsing tagger) and "features" are
            its node lines with tabs replaced by commas.
        """
        text = text.translate(self.hankaku_to_zenkaku_table)
        nbest: str = self.tagger.nbest(text, num=num)
        candidates = [
            {
                "raw": "\n".join(node_lines) + "\nEOS",
                "features": [line.replace("\t", ",") for line in node_lines],
            }
            for node_lines in self._split_nbest(nbest)
        ]
        return {"candidates": candidates}

    @staticmethod
    def _parse_partial_cost(output: str) -> int:
        """Extract the path cost from the partial-parsing tagger's output.

        The output is one "0" per node followed by the cost printed at EOS.
        Stripping the leading zeros leaves the cost; nothing left means the
        cost itself was 0.
        """
        digits = output.strip().lstrip("0")
        return int(digits) if digits else 0

    def add_mecab_costs(self, results: dict):
        """Add each candidate's MeCab path cost under "mecab_cost"."""
        for candidate in results["candidates"]:
            output = self.tagger_p.parse(candidate["raw"])
            candidate["mecab_cost"] = self._parse_partial_cost(output)

    def add_phonemes(self, text: str, results: dict):
        """Add pyopenjtalk's NJD result and phoneme sequence to each candidate.

        Args:
            text: unused; kept for interface compatibility.
            results: output of ``run_mecab``; each candidate's "features"
                are fed to pyopenjtalk.
        """
        for candidate in results["candidates"]:
            njd_result = pyopenjtalk.run_njd_from_mecab(candidate["features"])
            # modify_kanji_yomi is deliberately not applied (presumably so the
            # MeCab reading being evaluated is not overridden).
            postprocessed = pyopenjtalk.modify_filler_accent(njd_result)
            postprocessed = pyopenjtalk.retreat_acc_nuc(postprocessed)
            postprocessed = pyopenjtalk.modify_acc_after_chaining(postprocessed)
            postprocessed = pyopenjtalk.process_odori_features(postprocessed)
            labels = pyopenjtalk.make_label(postprocessed)
            # Drop the first and last labels (presumably the surrounding
            # silences) and take the current phoneme, which sits between "-"
            # and "+" in each full-context label.
            phonemes = [label.split("-")[1].split("+")[0] for label in labels[1:-1]]
            candidate["njd_result"] = njd_result
            candidate["postprocessed"] = postprocessed
            candidate["phonemes"] = phonemes

    @staticmethod
    def _pack_ctc_targets(id_lists, device):
        """Concatenate token-id lists into ``F.ctc_loss``'s 1-D targets + lengths.

        The dtype is pinned to int64: without it, ``torch.tensor([])`` (every
        sequence empty) would be float32, which ``ctc_loss`` rejects.
        """
        flat = [token_id for ids in id_lists for token_id in ids]
        targets = torch.tensor(flat, dtype=torch.long, device=device)
        lengths = torch.tensor(
            [len(ids) for ids in id_lists], dtype=torch.long, device=device
        )
        return targets, lengths

    def add_ctc_loss(self, audio_16khz: np.ndarray, results: dict):
        """Score every candidate's phonemes against the audio with CTC loss.

        Adds "ctc_loss" to each candidate. Also stores HuBERT's raw logits
        ("hubert_logits") and its own greedy transcription together with its
        CTC loss ("hubert_prediction") in ``results`` as a reference point.
        """
        SAMPLING_RATE = 16000
        # Pad like ReazonSpeech's ASR model: 1 s of silence before, 0.5 s after.
        audio_16khz = np.concatenate(
            [np.zeros(SAMPLING_RATE), audio_16khz, np.zeros(SAMPLING_RATE // 2)]
        )
        inputs = self.wav2vec_processor(
            audio_16khz,
            sampling_rate=SAMPLING_RATE,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            outputs = self.phoneme_hubert(**inputs)
        logits = outputs.logits
        assert logits.ndim == 3, logits.shape

        # Greedy (argmax) decoding of HuBERT's own prediction.
        predicted_ids = logits.argmax(-1)
        predicted_phonemes: list[str] = self.wav2vec_processor.decode(
            predicted_ids[0], spaces_between_special_tokens=True
        ).split()
        # Normalized key, so a candidate with the same phonemes is deduplicated.
        predicted_phonemes_str = " ".join(predicted_phonemes)

        # [length, 1, vocab_size]: time-major layout expected by F.ctc_loss.
        log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)

        # Unique phoneme sequences to score (prediction first, then candidates).
        phoneme_sequences = {predicted_phonemes_str: predicted_phonemes}
        for candidate in results["candidates"]:
            phoneme_sequences[" ".join(candidate["phonemes"])] = candidate["phonemes"]

        tokenizer = self.wav2vec_processor.tokenizer
        targets, target_lengths = self._pack_ctc_targets(
            [tokenizer.convert_tokens_to_ids(p) for p in phoneme_sequences.values()],
            self.device,
        )
        num_sequences = len(phoneme_sequences)
        input_lengths = torch.full(
            (num_sequences,), log_probs.size(0), dtype=torch.long, device=self.device
        )
        # Transformers disables cuDNN for CTC loss; do the same.
        with torch.backends.cudnn.flags(enabled=False):
            losses = F.ctc_loss(
                log_probs.expand(-1, num_sequences, -1),
                targets,
                input_lengths,
                target_lengths,
                blank=self.phoneme_hubert.config.pad_token_id,
                reduction="none",
            )
        loss_by_sequence = dict(zip(phoneme_sequences, losses.cpu().tolist()))

        for candidate in results["candidates"]:
            candidate["ctc_loss"] = loss_by_sequence[" ".join(candidate["phonemes"])]
        results["hubert_logits"] = logits
        results["hubert_prediction"] = {
            "phonemes": predicted_phonemes,
            "ctc_loss": loss_by_sequence[predicted_phonemes_str],
        }
if __name__ == "__main__":
    # Usage: uv run src/__init__.py "<Japanese text>" path/to/audio.wav
    import sys

    import librosa

    # Validate arguments before the (slow) model load.
    if len(sys.argv) != 3:
        sys.exit(f"usage: {sys.argv[0]} TEXT AUDIO_FILE")
    text = sys.argv[1]
    audio_file = Path(sys.argv[2])

    device = "cuda" if torch.cuda.is_available() else "cpu"
    candidate_generator = CandidateGenerator(device)
    # Load resampled to the 16 kHz rate the generator expects; the returned
    # sampling rate is therefore not needed.
    audio_16khz, _ = librosa.load(audio_file, sr=16000)
    results = candidate_generator.generate(text, audio_16khz, num=10)
    for candidate in results["candidates"]:
        print(
            f"Cost: {candidate['mecab_cost']}, CTC Loss: {candidate['ctc_loss']:.3f}, Phonemes: {' '.join(candidate['phonemes'])}"
        )
    print(results)