from typing import List, Tuple

import torch

from modules.wenet_extractor.utils.common import log_add


class Sequence:
    __slots__ = ("hyp", "score", "cache")
    def __init__(
        self,
        hyp: List[int],
        score: float,
        cache: List[torch.Tensor],
    ):
        self.hyp = hyp
        self.score = score
        self.cache = cache

class PrefixBeamSearch:
    def __init__(self, encoder, predictor, joint, ctc, blank):
        self.encoder = encoder
        self.predictor = predictor
        self.joint = joint
        self.ctc = ctc
        self.blank = blank

    def forward_decoder_one_step(
        self, encoder_x: torch.Tensor, pre_t: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        # Run the predictor one step on the last emitted tokens, then combine
        # the result with the current encoder frame in the joint network.
        padding = torch.zeros(pre_t.size(0), 1, device=encoder_x.device)
        pre_t, new_cache = self.predictor.forward_step(
            pre_t.unsqueeze(-1), padding, cache
        )
        x = self.joint(encoder_x, pre_t)  # (beam, 1, 1, vocab)
        x = x.log_softmax(dim=-1)
        return x, new_cache
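
    # Example shapes for forward_decoder_one_step with beam_size = 4
    # (illustrative; assumes a WeNet-style joint network that broadcasts the
    # batch-1 encoder frame across the beam, as the original shape comment
    # above implies):
    #   encoder_x: (1, 1, encoder_dim)   the current encoder frame
    #   pre_t:     (4,)                  last token id of each hypothesis
    #   logp:      (4, 1, 1, vocab)      per-hypothesis log-probabilities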

    def prefix_beam_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        beam_size: int = 5,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        ctc_weight: float = 0.3,
        transducer_weight: float = 0.7,
    ):
        """Prefix beam search for the transducer, with CTC shallow fusion.

        See also wenet.transducer.transducer.beam_search.
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        device = speech.device
        batch_size = speech.shape[0]
        # Only batch size 1 is supported
        assert batch_size == 1

        # 1. Encoder
        encoder_out, _ = self.encoder(
            speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks
        )  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        ctc_probs = self.ctc.log_softmax(encoder_out).squeeze(0)
        beam_init: List[Sequence] = []

        # 2. Init the beam, using Sequence to hold each beam entry
        cache = self.predictor.init_state(1, method="zero", device=device)
        beam_init.append(Sequence(hyp=[self.blank], score=0.0, cache=cache))

        # 3. Start decoding (note: we use breadth-first search)
        # !!!! In this decoding method one frame never emits multiple units. !!!!
        # !!!! Experiments show that this strategy has little impact. !!!!
        for i in range(maxlen):
            # 3.1 Build the input: the decoder takes the last token of each
            # hypothesis to predict the next token
            input_hyp = [s.hyp[-1] for s in beam_init]
            input_hyp_tensor = torch.tensor(input_hyp, dtype=torch.int, device=device)
            # Batch the predictor states across the beam
            cache_batch = self.predictor.cache_to_batch([s.cache for s in beam_init])
            # Build a score tensor for the torch.add() below
            scores = torch.tensor([s.score for s in beam_init]).to(device)

            # 3.2 Forward the decoder one step
            logp, new_cache = self.forward_decoder_one_step(
                encoder_out[:, i, :].unsqueeze(1),
                input_hyp_tensor,
                cache_batch,
            )  # logp: (N, 1, 1, vocab_size)
            logp = logp.squeeze(1).squeeze(1)  # logp: (N, vocab_size)
            new_cache = self.predictor.batch_to_cache(new_cache)

            # 3.3 Shallow fusion of the transducer and CTC scores; an LM score
            # could be added here in the same way. Both terms are moved back
            # to probability space, interpolated, and re-logged:
            #   logp = log(transducer_weight * P_transducer + ctc_weight * P_ctc)
            logp = torch.log(
                torch.add(
                    transducer_weight * torch.exp(logp),
                    ctc_weight * torch.exp(ctc_probs[i].unsqueeze(0)),
                )
            )
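            # E.g. with transducer_weight = 0.7 and ctc_weight = 0.3, a token
            # with P_transducer = 0.60 and P_ctc = 0.40 gets
            # log(0.7 * 0.60 + 0.3 * 0.40) = log(0.54) ≈ -0.616.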

            # 3.4 First beam prune
            top_k_logp, top_k_index = logp.topk(beam_size)  # (N, beam_size)
            scores = torch.add(scores.unsqueeze(1), top_k_logp)

            # 3.5 Generate the new beam (N * beam_size candidates)
            beam_A = []
            for j in range(len(beam_init)):
                base_seq = beam_init[j]
                for t in range(beam_size):
                    if top_k_index[j, t] == self.blank:
                        # Blank: only the score changes; the hypothesis and
                        # the predictor state are kept
                        new_seq = Sequence(
                            hyp=base_seq.hyp.copy(),
                            score=scores[j, t].item(),
                            cache=base_seq.cache,
                        )
                        beam_A.append(new_seq)
                    else:
                        # Non-blank unit: update the hypothesis, the score
                        # and the predictor state
                        hyp_new = base_seq.hyp.copy()
                        hyp_new.append(top_k_index[j, t].item())
                        new_seq = Sequence(
                            hyp=hyp_new, score=scores[j, t].item(), cache=new_cache[j]
                        )
                        beam_A.append(new_seq)

            # 3.6 Prefix fusion: candidates that share the same hypothesis are
            # merged by log-adding their scores, so a prefix accumulates the
            # probability mass of all paths that produce it
            fusion_A = [beam_A[0]]
            for j in range(1, len(beam_A)):
                s1 = beam_A[j]
                if_do_append = True
                for t in range(len(fusion_A)):
                    if s1.hyp == fusion_A[t].hyp:
                        fusion_A[t].score = log_add([fusion_A[t].score, s1.score])
                        if_do_append = False
                        break
                if if_do_append:
                    fusion_A.append(s1)

            # 4. Second beam prune: keep the top beam_size fused hypotheses
            fusion_A.sort(key=lambda x: x.score, reverse=True)
            beam_init = fusion_A[:beam_size]
        return beam_init, encoder_out
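

# A minimal, hypothetical usage sketch (not part of the original module). It
# assumes `model` is a WeNet-style transducer exposing `encoder`, `predictor`,
# `joint` and `ctc` submodules with the interfaces used above, and that
# `speech` / `speech_lengths` hold a single utterance (batch size 1):
#
#   searcher = PrefixBeamSearch(
#       model.encoder, model.predictor, model.joint, model.ctc, blank=0
#   )
#   beam, encoder_out = searcher.prefix_beam_search(
#       speech,          # (1, num_frames, feat_dim)
#       speech_lengths,  # (1,)
#       beam_size=5,
#       ctc_weight=0.3,
#       transducer_weight=0.7,
#   )
#   best = beam[0]
#   tokens = best.hyp[1:]  # drop the leading blank used for initialization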