# This module is from [WeNet](https://github.com/wenet-e2e/wenet).

# ## Citations

# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```
import torch
import torch.nn as nn
from typing import Tuple, Union, Optional, List

from modules.wenet_extractor.squeezeformer.subsampling import (
    DepthwiseConv2dSubsampling4,
    TimeReductionLayer1D,
    TimeReductionLayer2D,
    TimeReductionLayerStream,
)
from modules.wenet_extractor.squeezeformer.encoder_layer import (
    SqueezeformerEncoderLayer,
)
from modules.wenet_extractor.transformer.embedding import RelPositionalEncoding
from modules.wenet_extractor.transformer.attention import MultiHeadedAttention
from modules.wenet_extractor.squeezeformer.attention import (
    RelPositionMultiHeadedAttention,
)
from modules.wenet_extractor.squeezeformer.positionwise_feed_forward import (
    PositionwiseFeedForward,
)
from modules.wenet_extractor.squeezeformer.convolution import ConvolutionModule
from modules.wenet_extractor.utils.mask import make_pad_mask, add_optional_chunk_mask
from modules.wenet_extractor.utils.common import get_activation


class SqueezeformerEncoder(nn.Module):
    def __init__(
        self,
        input_size: int = 80,
        encoder_dim: int = 256,
        output_size: int = 256,
        attention_heads: int = 4,
        num_blocks: int = 12,
        reduce_idx: Optional[Union[int, List[int]]] = 5,
        recover_idx: Optional[Union[int, List[int]]] = 11,
        feed_forward_expansion_factor: int = 4,
        dw_stride: bool = False,
        input_dropout_rate: float = 0.1,
        pos_enc_layer_type: str = "rel_pos",
        time_reduction_layer_type: str = "conv1d",
        do_rel_shift: bool = True,
        feed_forward_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.1,
        cnn_module_kernel: int = 31,
        cnn_norm_type: str = "batch_norm",
        dropout: float = 0.1,
        causal: bool = False,
        adaptive_scale: bool = True,
        activation_type: str = "swish",
        init_weights: bool = True,
        global_cmvn: torch.nn.Module = None,
        normalize_before: bool = False,
        use_dynamic_chunk: bool = False,
        concat_after: bool = False,
        static_chunk_size: int = 0,
        use_dynamic_left_chunk: bool = False,
    ):
| """Construct SqueezeformerEncoder | |
| Args: | |
| input_size to use_dynamic_chunk, see in Transformer BaseEncoder. | |
| encoder_dim (int): The hidden dimension of encoder layer. | |
| output_size (int): The output dimension of final projection layer. | |
| attention_heads (int): Num of attention head in attention module. | |
| num_blocks (int): Num of encoder layers. | |
| reduce_idx Optional[Union[int, List[int]]]: | |
| reduce layer index, from 40ms to 80ms per frame. | |
| recover_idx Optional[Union[int, List[int]]]: | |
| recover layer index, from 80ms to 40ms per frame. | |
| feed_forward_expansion_factor (int): Enlarge coefficient of FFN. | |
| dw_stride (bool): Whether do depthwise convolution | |
| on subsampling module. | |
| input_dropout_rate (float): Dropout rate of input projection layer. | |
| pos_enc_layer_type (str): Self attention type. | |
| time_reduction_layer_type (str): Conv1d or Conv2d reduction layer. | |
| do_rel_shift (bool): Whether to do relative shift | |
| operation on rel-attention module. | |
| cnn_module_kernel (int): Kernel size of CNN module. | |
| activation_type (str): Encoder activation function type. | |
| use_cnn_module (bool): Whether to use convolution module. | |
| cnn_module_kernel (int): Kernel size of convolution module. | |
| adaptive_scale (bool): Whether to use adaptive scale. | |
| init_weights (bool): Whether to initialize weights. | |
| causal (bool): whether to use causal convolution or not. | |
| """ | |
        super(SqueezeformerEncoder, self).__init__()
        self.global_cmvn = global_cmvn
        self.reduce_idx: Optional[Union[int, List[int]]] = (
            [reduce_idx] if type(reduce_idx) == int else reduce_idx
        )
        self.recover_idx: Optional[Union[int, List[int]]] = (
            [recover_idx] if type(recover_idx) == int else recover_idx
        )
        self.check_ascending_list()
        if reduce_idx is None:
            self.time_reduce = None
        else:
            if recover_idx is None:
                self.time_reduce = "normal"  # no recovery at the end
            else:
                self.time_reduce = "recover"  # recovery at the end
                assert len(self.reduce_idx) == len(self.recover_idx)
            self.reduce_stride = 2
        self._output_size = output_size
        self.normalize_before = normalize_before
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        self.pos_enc_layer_type = pos_enc_layer_type
        activation = get_activation(activation_type)

        # self-attention module definition
        if pos_enc_layer_type != "rel_pos":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        else:
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                encoder_dim,
                attention_dropout_rate,
                do_rel_shift,
                adaptive_scale,
                init_weights,
            )

        # feed-forward module definition
        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            encoder_dim,
            encoder_dim * feed_forward_expansion_factor,
            feed_forward_dropout_rate,
            activation,
            adaptive_scale,
            init_weights,
        )

        # convolution module definition
        convolution_layer = ConvolutionModule
        convolution_layer_args = (
            encoder_dim,
            cnn_module_kernel,
            activation,
            cnn_norm_type,
            causal,
            True,
            adaptive_scale,
            init_weights,
        )

        self.embed = DepthwiseConv2dSubsampling4(
            1,
            encoder_dim,
            RelPositionalEncoding(encoder_dim, dropout_rate=0.1),
            dw_stride,
            input_size,
            input_dropout_rate,
            init_weights,
        )

        self.preln = nn.LayerNorm(encoder_dim)
        self.encoders = torch.nn.ModuleList(
            [
                SqueezeformerEncoderLayer(
                    encoder_dim,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    convolution_layer(*convolution_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    normalize_before,
                    dropout,
                    concat_after,
                )
                for _ in range(num_blocks)
            ]
        )
        if time_reduction_layer_type == "conv1d":
            time_reduction_layer = TimeReductionLayer1D
            time_reduction_layer_args = {
                "channel": encoder_dim,
                "out_dim": encoder_dim,
            }
        elif time_reduction_layer_type == "stream":
            time_reduction_layer = TimeReductionLayerStream
            time_reduction_layer_args = {
                "channel": encoder_dim,
                "out_dim": encoder_dim,
            }
        else:
            time_reduction_layer = TimeReductionLayer2D
            time_reduction_layer_args = {"encoder_dim": encoder_dim}

        self.time_reduction_layer = time_reduction_layer(**time_reduction_layer_args)
        self.time_recover_layer = nn.Linear(encoder_dim, encoder_dim)
        self.final_proj = None
        if output_size != encoder_dim:
            self.final_proj = nn.Linear(encoder_dim, output_size)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
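        # Offline (full-utterance) forward. Roughly, as a reading of the code
        # below rather than an authoritative spec: build padding masks from
        # xs_lens, apply global CMVN and the conv2d subsampling front-end
        # (about 4x in time), then run the Squeezeformer blocks, temporarily
        # halving the frame rate at reduce_idx and restoring it at recover_idx
        # before the optional final projection. Returns (output, masks), both
        # at the subsampled frame rate.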
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(
            xs,
            masks,
            self.use_dynamic_chunk,
            self.use_dynamic_left_chunk,
            decoding_chunk_size,
            self.static_chunk_size,
            num_decoding_left_chunks,
        )
        xs_lens = mask_pad.squeeze(1).sum(1)
        xs = self.preln(xs)
        recover_activations: List[
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        ] = []
        index = 0
        for i, layer in enumerate(self.encoders):
            if self.reduce_idx is not None:
                if self.time_reduce is not None and i in self.reduce_idx:
                    recover_activations.append((xs, chunk_masks, pos_emb, mask_pad))
                    xs, xs_lens, chunk_masks, mask_pad = self.time_reduction_layer(
                        xs, xs_lens, chunk_masks, mask_pad
                    )
                    pos_emb = pos_emb[:, ::2, :]
                    index += 1
            if self.recover_idx is not None:
                if self.time_reduce == "recover" and i in self.recover_idx:
                    index -= 1
                    (
                        recover_tensor,
                        recover_chunk_masks,
                        recover_pos_emb,
                        recover_mask_pad,
                    ) = recover_activations[index]
                    # recover output length for ctc decode
                    xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
                    xs = self.time_recover_layer(xs)
                    recoverd_t = recover_tensor.size(1)
                    xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
                    chunk_masks = recover_chunk_masks
                    pos_emb = recover_pos_emb
                    mask_pad = recover_mask_pad
                    xs = xs.masked_fill(~mask_pad[:, 0, :].unsqueeze(-1), 0.0)
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        if self.final_proj is not None:
            xs = self.final_proj(xs)
        return xs, masks

    def check_ascending_list(self):
        if self.reduce_idx is not None:
            assert self.reduce_idx == sorted(
                self.reduce_idx
            ), "reduce_idx should be int or ascending list"
        if self.recover_idx is not None:
            assert self.recover_idx == sorted(
                self.recover_idx
            ), "recover_idx should be int or ascending list"

    def calculate_downsampling_factor(self, i: int) -> int:
        if self.reduce_idx is None:
            return 1
        else:
            reduce_exp, recover_exp = 0, 0
            for exp, rd_idx in enumerate(self.reduce_idx):
                if i >= rd_idx:
                    reduce_exp = exp + 1
            if self.recover_idx is not None:
                for exp, rc_idx in enumerate(self.recover_idx):
                    if i >= rc_idx:
                        recover_exp = exp + 1
            return int(2 ** (reduce_exp - recover_exp))
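
    # A worked example of calculate_downsampling_factor with the default
    # schedule (reduce_idx=[5], recover_idx=[11]): blocks 0-4 run at factor
    # 2**0 = 1, blocks 5-10 at factor 2**1 = 2 (one reduction, no recovery
    # yet), and from block 11 onwards the factor drops back to 2**0 = 1.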

    def forward_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward just one chunk

        Args:
            xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            offset (int): current offset in encoder output time stamp
            required_cache_size (int): cache size required for next chunk
                computation
                >=0: actual cache size
                <0: means all history cache is required
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (elayers, b=1, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`

        Returns:
            torch.Tensor: output of current input xs,
                with shape (b=1, chunk_size, hidden-dim).
            torch.Tensor: new attention cache required for next chunk, with
                dynamic shape (elayers, head, ?, d_k * 2)
                depending on required_cache_size.
            torch.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.
        """
        assert xs.size(0) == 1
        # tmp_masks is just for interface compatibility
        tmp_masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
        tmp_masks = tmp_masks.unsqueeze(1)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
        # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
        elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
        chunk_size = xs.size(1)
        attention_key_size = cache_t1 + chunk_size
        pos_emb = self.embed.position_encoding(
            offset=offset - cache_t1, size=attention_key_size
        )
        if required_cache_size < 0:
            next_cache_start = 0
        elif required_cache_size == 0:
            next_cache_start = attention_key_size
        else:
            next_cache_start = max(attention_key_size - required_cache_size, 0)
        r_att_cache = []
        r_cnn_cache = []
        mask_pad = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
        mask_pad = mask_pad.unsqueeze(1)
        max_att_len: int = 0
        recover_activations: List[
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        ] = []
        index = 0
        xs_lens = torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int)
        xs = self.preln(xs)
        for i, layer in enumerate(self.encoders):
            # NOTE(xcsong): Before layer.forward
            #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
            #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
            if self.reduce_idx is not None:
                if self.time_reduce is not None and i in self.reduce_idx:
                    recover_activations.append((xs, att_mask, pos_emb, mask_pad))
                    xs, xs_lens, att_mask, mask_pad = self.time_reduction_layer(
                        xs, xs_lens, att_mask, mask_pad
                    )
                    pos_emb = pos_emb[:, ::2, :]
                    index += 1
            if self.recover_idx is not None:
                if self.time_reduce == "recover" and i in self.recover_idx:
                    index -= 1
                    (
                        recover_tensor,
                        recover_att_mask,
                        recover_pos_emb,
                        recover_mask_pad,
                    ) = recover_activations[index]
                    # recover output length for ctc decode
                    xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
                    xs = self.time_recover_layer(xs)
                    recoverd_t = recover_tensor.size(1)
                    xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
                    att_mask = recover_att_mask
                    pos_emb = recover_pos_emb
                    mask_pad = recover_mask_pad
                    if att_mask.size(1) != 0:
                        xs = xs.masked_fill(~att_mask[:, 0, :].unsqueeze(-1), 0.0)
            factor = self.calculate_downsampling_factor(i)
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                att_cache=att_cache[i : i + 1][:, :, ::factor, :][
                    :, :, : pos_emb.size(1) - xs.size(1), :
                ]
                if elayers > 0
                else att_cache[:, :, ::factor, :],
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache,
            )
            # NOTE(xcsong): After layer.forward
            #   shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
            #   shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
            cached_att = new_att_cache[:, :, next_cache_start // factor :, :]
            cached_cnn = new_cnn_cache.unsqueeze(0)
            cached_att = (
                cached_att.unsqueeze(3).repeat(1, 1, 1, factor, 1).flatten(2, 3)
            )
            if i == 0:
                # record length for the first block as max length
                max_att_len = cached_att.size(2)
            r_att_cache.append(cached_att[:, :, :max_att_len, :])
            r_cnn_cache.append(cached_cnn)
        # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
        #   ? may be larger than cache_t1, it depends on required_cache_size
        r_att_cache = torch.cat(r_att_cache, dim=0)
        # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
        if self.final_proj is not None:
            xs = self.final_proj(xs)
        return (xs, r_att_cache, r_cnn_cache)

    def forward_chunk_by_chunk(
        self,
        xs: torch.Tensor,
        decoding_chunk_size: int,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward the input chunk by chunk with chunk_size, in a streaming
            fashion.

        Here we should pay special attention to the computation cache in the
        streaming-style forward chunk by chunk. Three things should be taken
        into account for computation in the current network:
            1. transformer/conformer encoder layers output cache
            2. convolution in conformer
            3. convolution in subsampling

        However, we don't implement subsampling cache because:
            1. We can make the subsampling module output the right result by
               overlapping the input instead of caching left context. Even
               though this wastes some computation, subsampling only takes a
               very small fraction of the computation in the whole model.
            2. Typically there are several convolution layers with subsampling
               in the subsampling module, so it is tricky and complicated to
               cache across convolution layers with different subsampling
               rates.
            3. Currently, nn.Sequential is used to stack all the convolution
               layers in subsampling; we would need to rewrite it to make it
               work with a cache, which is not preferred.

        Args:
            xs (torch.Tensor): (1, max_len, dim)
            decoding_chunk_size (int): decoding chunk size
        """
        assert decoding_chunk_size > 0
        # The model is trained by static or dynamic chunk
        assert self.static_chunk_size > 0 or self.use_dynamic_chunk
        subsampling = self.embed.subsampling_rate
        context = self.embed.right_context + 1  # Add current frame
        stride = subsampling * decoding_chunk_size
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = xs.size(1)
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        outputs = []
        offset = 0
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks

        # Feed forward overlap input step by step
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            chunk_xs = xs[:, cur:end, :]
            (y, att_cache, cnn_cache) = self.forward_chunk(
                chunk_xs, offset, required_cache_size, att_cache, cnn_cache
            )
            outputs.append(y)
            offset += y.size(1)
        ys = torch.cat(outputs, 1)
        masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool)
        return ys, masks
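

# A minimal usage sketch (illustrative only, not an official WeNet example).
# It assumes the modules.wenet_extractor package is importable and simply runs
# the encoder offline and then chunk by chunk on random fbank-like features;
# the "stream" time-reduction variant is chosen because it is the one intended
# for chunk-based decoding.
if __name__ == "__main__":
    encoder = SqueezeformerEncoder(
        input_size=80,
        encoder_dim=256,
        output_size=256,
        time_reduction_layer_type="stream",
        static_chunk_size=16,  # forward_chunk_by_chunk asserts a chunked model
    )
    encoder.eval()
    feats = torch.randn(1, 400, 80)  # (B, T, input_size) fbank-like features
    feats_lens = torch.tensor([400])  # valid length per utterance
    with torch.no_grad():
        offline_out, offline_masks = encoder(feats, feats_lens)
        stream_out, stream_masks = encoder.forward_chunk_by_chunk(
            feats, decoding_chunk_size=16
        )
    # Both outputs have shape roughly (1, T // 4, output_size); small numerical
    # differences are expected because streaming attention sees limited context.
    print(offline_out.shape, stream_out.shape)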