Spaces:
Runtime error
Runtime error
| import torch | |
| from wav2vecasr.models import MultiTaskWav2Vec2 | |
| from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, \ | |
| Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor | |
| import pyctcdecode | |
| import json | |
| import re | |
| from sys import platform | |
class PhonemeASRModel:
    """
    Common interface of the phoneme ASR models in this module.

    Every method here is a placeholder that returns ``None``; subclasses are
    expected to override them.
    """

    def get_l2_phoneme_sequence(self, audio):
        """
        Predict the phonemes uttered by an L2 speaker.

        :param audio: audio sampled at 16k sampling rate with torchaudio
        :type audio: array
        :return: predicted phonemes for L2 speaker
        :rtype: array
        """
        return None

    def standardise_g2p_phoneme_sequence(self, phones):
        """
        Standardise G2P output to facilitate mispronunciation detection.

        :param phones: native speaker phones predicted by G2P model
        :type phones: array
        :return: standardised native speaker phoneme sequence that aligns with phoneme classes by the model
        :rtype: array
        """
        return None

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """
        Standardise L2-ARCTIC annotations to facilitate testing.

        :param phones: native speaker phones as annotated in l2 artic
        :type phones: array
        :return: standardised native speaker phoneme sequence that aligns with phoneme classes by the model
        :rtype: array
        """
        return None
class MultitaskPhonemeASRModel(PhonemeASRModel):
    """
    Phoneme recogniser built on a ``MultiTaskWav2Vec2`` model, decoded greedily (argmax).
    """

    # Decoded tokens that are dropped from the returned phoneme sequence.
    _IGNORED_TOKENS = ("sil", "sp")

    def __init__(self, model_path, best_model_vocab_path, device):
        """
        :param model_path: path of the saved state dict loaded into the multitask model
        :param best_model_vocab_path: vocabulary file used to build the CTC tokenizer
        :param device: device the model and the input tensors are placed on
        """
        self.device = device
        tokenizer = Wav2Vec2CTCTokenizer(
            best_model_vocab_path,
            unk_token="[UNK]",
            pad_token="[PAD]",
            word_delimiter_token="|",
        )
        feature_extractor = Wav2Vec2FeatureExtractor(
            feature_size=1,
            sampling_rate=16000,
            padding_value=0.0,
            do_normalize=True,
            return_attention_mask=False,
        )
        processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

        # The CTC head is sized to the tokenizer vocabulary, hence
        # ignore_mismatched_sizes=True when loading the pretrained backbone.
        wav2vec2_backbone = Wav2Vec2ForCTC.from_pretrained(
            pretrained_model_name_or_path="facebook/wav2vec2-xls-r-300m",
            ignore_mismatched_sizes=True,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
            vocab_size=len(processor.tokenizer),
            output_hidden_states=True,
        )
        wav2vec2_backbone = wav2vec2_backbone.to(device)

        model = MultiTaskWav2Vec2(
            wav2vec2_backbone=wav2vec2_backbone,
            backbone_hidden_size=1024,
            projection_hidden_size=256,
            num_accent_class=3,
        )
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        model.to(device)
        model.eval()

        self.multitask_model = model
        self.processor = processor

    def get_l2_phoneme_sequence(self, audio):
        """
        :param audio: audio sampled at 16k sampling rate with torchaudio
            (presumably a 1-D torch tensor; it must support ``unsqueeze`` - confirm against caller)
        :return: predicted phonemes for L2 speaker, without "sil", "sp" or empty entries
        :rtype: list
        """
        audio = audio.unsqueeze(0)
        audio = self.processor(audio, sampling_rate=16000).input_values[0]
        audio = torch.tensor(audio, device=self.device)

        with torch.no_grad():
            # Only the second output (the phoneme logits) is used here.
            _, lm_logits, _, _ = self.multitask_model(audio)

        lm_preds = torch.argmax(lm_logits, dim=-1)

        # Decode output results
        pred_decoded = self.processor.batch_decode(lm_preds)
        pred_phones = pred_decoded[0].split(" ")

        # Remove sil and sp, plus the empty string that "".split(" ") yields
        # when nothing was decoded.
        return [
            phone
            for phone in pred_phones
            if phone and phone not in self._IGNORED_TOKENS
        ]

    def standardise_g2p_phoneme_sequence(self, phones):
        """Return the G2P phones unchanged."""
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """Return the L2-ARCTIC ground-truth phones unchanged."""
        return phones
class Wav2Vec2PhonemeASRModel(PhonemeASRModel):
    """
    Uses greedy decoding
    """

    def __init__(self, model_path, processor_path):
        """
        :param model_path: name or path passed to ``Wav2Vec2ForCTC.from_pretrained``
        :param processor_path: name or path passed to ``Wav2Vec2Processor.from_pretrained``
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        self.processor = Wav2Vec2Processor.from_pretrained(processor_path)

    def get_l2_phoneme_sequence(self, audio):
        """
        :param audio: audio sampled at 16k sampling rate with torchaudio
        :type audio: array
        :return: predicted phonemes for L2 speaker, with empty strings removed
        :rtype: list
        """
        input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)

        # Inference only: skip autograd bookkeeping, as the multitask model does.
        with torch.no_grad():
            logits = self.model(input_dict.input_values.to(self.device)).logits

        # Ids of the first (only) item in the batch.
        pred_ids = torch.argmax(logits, dim=-1)[0]

        # NOTE(review): pred_ids is a single id sequence, so batch_decode
        # decodes every frame id on its own; repeated frames are presumably
        # not merged - confirm this is intended.
        decoded = self.processor.batch_decode(pred_ids)
        return [phoneme for phoneme in decoded if phoneme != ""]

    def standardise_g2p_phoneme_sequence(self, phones):
        """Return the G2P phones unchanged."""
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """Return the phones with every digit removed (presumably stress markers - confirm)."""
        return [re.sub(r'\d', "", phone_str) for phone_str in phones]
# TODO: debug this class on Linux, because KenLM is not supported on Windows.
class Wav2Vec2OptimisedPhonemeASRModel(PhonemeASRModel):
    """
    Uses beam search for decoding, plus a KenLM language model when running
    on Linux with a ``kenlm_model_path``.
    """

    def __init__(self, model_path, vocab_json_path, kenlm_model_path):
        """
        :param model_path: name or path passed to ``Wav2Vec2ForCTC.from_pretrained``
        :param vocab_json_path: JSON vocabulary file (token -> id) used for the tokenizer and the decoder labels
        :param kenlm_model_path: KenLM model path; ignored when falsy or when not running on Linux
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Context manager so the file handle is always closed; explicit UTF-8
        # so non-ASCII phoneme symbols load the same way on every platform.
        with open(vocab_json_path, encoding="utf-8") as vocab_file:
            vocab_dict = json.load(vocab_file)

        tokenizer = Wav2Vec2CTCTokenizer(
            vocab_json_path,
            unk_token="[UNK]",
            pad_token="[PAD]",
            word_delimiter_token="|",
        )
        feature_extractor = Wav2Vec2FeatureExtractor(
            feature_size=1,
            sampling_rate=16000,
            padding_value=0.0,
            do_normalize=True,
            return_attention_mask=False,
        )

        # Order the labels by token id rather than by JSON key order, so that
        # label i corresponds to logit column i.
        labels = [
            token
            for token, _ in sorted(vocab_dict.items(), key=lambda item: item[1])
        ]

        if platform in ("linux", "linux2") and kenlm_model_path:
            # beam search + LM
            decoder = pyctcdecode.decoder.build_ctcdecoder(labels, kenlm_model_path=kenlm_model_path)
        else:
            # beam search
            decoder = pyctcdecode.decoder.build_ctcdecoder(labels)

        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        self.processor = Wav2Vec2ProcessorWithLM(
            feature_extractor=feature_extractor,
            tokenizer=tokenizer,
            decoder=decoder,
        )

    def get_l2_phoneme_sequence(self, audio):
        """
        :param audio: audio sampled at 16k sampling rate with torchaudio
        :type audio: array
        :return: predicted phonemes for L2 speaker, with empty strings removed
        :rtype: list
        """
        input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            logits = self.model(input_dict.input_values.to(self.device)).logits
        logits = logits.cpu().detach()

        # Softmax over the last (vocabulary) axis, then take the first
        # (only) item of the batch for the decoder.
        normalised_logits = torch.softmax(logits, dim=2).numpy()[0]
        output = self.processor.decode(normalised_logits)

        # "".split(" ") yields [""]; drop empty entries so an empty decode
        # returns an empty list.
        return [phone for phone in output.text.split(" ") if phone]

    def standardise_g2p_phoneme_sequence(self, phones):
        """Return the G2P phones unchanged."""
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """Return the phones with every digit removed (presumably stress markers - confirm)."""
        return [re.sub(r'\d', "", phone_str) for phone_str in phones]