# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
#
# ## Citations
#
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
#
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

from typing import Dict, Optional, Tuple

import torch

from modules.wenet_extractor.cif.predictor import MAELoss
from modules.wenet_extractor.paraformer.search.beam_search import Hypothesis
from modules.wenet_extractor.transformer.asr_model import ASRModel
from modules.wenet_extractor.transformer.ctc import CTC
from modules.wenet_extractor.transformer.decoder import TransformerDecoder
from modules.wenet_extractor.transformer.encoder import TransformerEncoder
from modules.wenet_extractor.utils.common import IGNORE_ID, add_sos_eos, th_accuracy
from modules.wenet_extractor.utils.mask import make_pad_mask


class Paraformer(ASRModel):
    """Paraformer: Fast and Accurate Parallel Transformer for
    Non-autoregressive End-to-End Speech Recognition
    see https://arxiv.org/pdf/2206.08317.pdf
    """

    def __init__(
        self,
        vocab_size: int,
        encoder: TransformerEncoder,
        decoder: TransformerDecoder,
        ctc: CTC,
        predictor,
        ctc_weight: float = 0.5,
        predictor_weight: float = 1.0,
        predictor_bias: int = 0,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
    ):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= predictor_weight <= 1.0, predictor_weight

        super().__init__(
            vocab_size,
            encoder,
            decoder,
            ctc,
            ctc_weight,
            ignore_id,
            reverse_weight,
            lsm_weight,
            length_normalized_loss,
        )
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.predictor_bias = predictor_bias
        self.criterion_pre = MAELoss(normalize_length=length_normalized_loss)
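
    # Construction sketch (illustrative; not part of the upstream WeNet source).
    # The encoder/decoder/ctc classes are the ones imported above; the predictor
    # is assumed to be a CIF-style predictor (e.g. from
    # modules.wenet_extractor.cif.predictor), and all sizes are placeholders.
    #
    #   encoder = TransformerEncoder(input_size=80)
    #   decoder = TransformerDecoder(
    #       vocab_size=5000, encoder_output_size=encoder.output_size()
    #   )
    #   ctc = CTC(odim=5000, encoder_output_size=encoder.output_size())
    #   model = Paraformer(
    #       vocab_size=5000,
    #       encoder=encoder,
    #       decoder=decoder,
    #       ctc=ctc,
    #       predictor=predictor,  # a CIF Predictor instance (assumed)
    #       ctc_weight=0.3,
    #       predictor_bias=1,  # 1 when sos/eos are appended to the targets
    #   )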

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)

        # 1. Encoder
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # 2a. Attention-decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, loss_pre = self._calc_att_loss(
                encoder_out, encoder_mask, text, text_lengths
            )
        else:
            # loss_att = None
            # loss_pre = None
            loss_att: torch.Tensor = torch.tensor(0)
            loss_pre: torch.Tensor = torch.tensor(0)

        # 2b. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
        else:
            loss_ctc = None

        if loss_ctc is None:
            loss = loss_att + self.predictor_weight * loss_pre
        # elif loss_att is None:
        elif loss_att == torch.tensor(0):
            loss = loss_ctc
        else:
            loss = (
                self.ctc_weight * loss_ctc
                + (1 - self.ctc_weight) * loss_att
                + self.predictor_weight * loss_pre
            )
        return {
            "loss": loss,
            "loss_att": loss_att,
            "loss_ctc": loss_ctc,
            "loss_pre": loss_pre,
        }
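
    # Usage sketch for forward() (illustrative; shapes and values are made up).
    # When both branches are active, the returned dict combines them as
    #   loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att
    #          + predictor_weight * loss_pre
    #
    #   feats = torch.randn(2, 200, 80)  # (batch, frames, feat_dim)
    #   feats_lens = torch.tensor([200, 160])
    #   tokens = torch.tensor([[12, 7, 9], [4, 5, IGNORE_ID]])  # padded targets
    #   token_lens = torch.tensor([3, 2])
    #   losses = model(feats, feats_lens, tokens, token_lens)
    #   losses["loss"].backward()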

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, float, torch.Tensor]:
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(
            encoder_out, ys_pad, encoder_mask, ignore_id=self.ignore_id
        )

        # 1. Forward decoder
        decoder_out, _, _ = self.decoder(
            encoder_out, encoder_mask, pre_acoustic_embeds, ys_pad_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre: torch.Tensor = self.criterion_pre(
            ys_pad_lens.type_as(pre_token_length), pre_token_length
        )

        return loss_att, acc_att, loss_pre

    def calc_predictor(self, encoder_out, encoder_mask):
        encoder_mask = (
            ~make_pad_mask(encoder_mask, max_len=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(
            encoder_out, None, encoder_mask, ignore_id=self.ignore_id
        )
        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

    def cal_decoder_with_predictor(
        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
    ):
        decoder_out, _, _ = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens

    def recognize(self):
        raise NotImplementedError

    def paraformer_greedy_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply greedy search on the attention decoder outputs

        Args:
            speech (torch.Tensor): (batch, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
        Returns:
            List[List[int]]: best token-id sequence for each utterance,
                with sos/eos, blank and unk removed
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        device = speech.device
        batch_size = speech.shape[0]

        # Let's assume B = batch_size and N = beam_size
        # 1. Encoder
        encoder_out, encoder_mask = self._forward_encoder(
            speech,
            speech_lengths,
            decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming,
        )  # (B, maxlen, encoder_dim)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # 2. Predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_mask)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
            predictor_outs[0],
            predictor_outs[1],
            predictor_outs[2],
            predictor_outs[3],
        )
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return torch.tensor([]), torch.tensor([])

        # 3. Decoder forward
        decoder_outs = self.cal_decoder_with_predictor(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length
        )
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        hyps = []
        b, n, d = decoder_out.size()
        for i in range(b):
            x = encoder_out[i, : encoder_out_lens[i], :]
            am_scores = decoder_out[i, : pre_token_length[i], :]
            yseq = am_scores.argmax(dim=-1)
            score = am_scores.max(dim=-1)[0]
            score = torch.sum(score, dim=-1)
            # pad with mask tokens to ensure compatibility with sos/eos tokens
            yseq = torch.tensor(
                [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
            )
            nbest_hyps = [Hypothesis(yseq=yseq, score=score)]

            for hyp in nbest_hyps:
                assert isinstance(hyp, (Hypothesis)), type(hyp)

                # remove sos/eos and get hyps
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # remove blank symbol id and unk id, which are assumed to be 0
                # and 1
                token_int = list(filter(lambda x: x != 0 and x != 1, token_int))
                hyps.append(token_int)
        return hyps
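
    # Decoding sketch (illustrative): greedy search takes the argmax token at
    # each predictor-estimated position, then strips sos/eos and the ids 0/1
    # (blank/unk). Feature extraction and the id-to-symbol mapping are assumed
    # to happen outside this file.
    #
    #   with torch.no_grad():
    #       hyps = model.paraformer_greedy_search(feats, feats_lens)
    #   # hyps: one list of token ids per utterance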

    def paraformer_beam_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        beam_search: torch.nn.Module = None,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply beam search on attention decoder outputs

        Args:
            speech (torch.Tensor): (batch, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            beam_search (torch.nn.Module): beam search module
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
        Returns:
            List[List[int]]: best token-id sequence for each utterance,
                with sos/eos, blank and unk removed
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        device = speech.device
        batch_size = speech.shape[0]

        # Let's assume B = batch_size and N = beam_size
        # 1. Encoder
        encoder_out, encoder_mask = self._forward_encoder(
            speech,
            speech_lengths,
            decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming,
        )  # (B, maxlen, encoder_dim)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # 2. Predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_mask)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
            predictor_outs[0],
            predictor_outs[1],
            predictor_outs[2],
            predictor_outs[3],
        )
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return torch.tensor([]), torch.tensor([])

        # 3. Decoder forward
        decoder_outs = self.cal_decoder_with_predictor(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length
        )
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        hyps = []
        b, n, d = decoder_out.size()
        for i in range(b):
            x = encoder_out[i, : encoder_out_lens[i], :]
            am_scores = decoder_out[i, : pre_token_length[i], :]
            if beam_search is not None:
                nbest_hyps = beam_search(x=x, am_scores=am_scores)
                nbest_hyps = nbest_hyps[:1]
            else:
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # pad with mask tokens to ensure compatibility with sos/eos
                # tokens
                yseq = torch.tensor(
                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
                )
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]

            for hyp in nbest_hyps:
                assert isinstance(hyp, (Hypothesis)), type(hyp)

                # remove sos/eos and get hyps
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # remove blank symbol id and unk id, which are assumed to be 0
                # and 1
                token_int = list(filter(lambda x: x != 0 and x != 1, token_int))
                hyps.append(token_int)
        return hyps
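
    # Decoding sketch (illustrative): beam_search is expected to be a module
    # with the call signature used above, beam_search(x=..., am_scores=...),
    # returning ranked Hypothesis objects; constructing it is outside this file.
    #
    #   with torch.no_grad():
    #       hyps = model.paraformer_beam_search(
    #           feats, feats_lens, beam_search=beam_search
    #       )
    #   # With beam_search=None this falls back to the same greedy behaviour as
    #   # paraformer_greedy_search.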