from typing import Optional, Dict

import torch
from torch import Tensor


def waitk_p_choose(
    tgt_len: int,
    src_len: int,
    bsz: int,
    waitk_lagging: int,
    key_padding_mask: Optional[Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
):
    """
    Build the hard read/write probabilities for wait-k decoding:
    target step n places probability 1 on source position n + waitk_lagging - 1.
    Returns p_choose of shape (bsz, tgt_len, src_len); during incremental
    inference only the last target step is returned.
    """
    max_src_len = src_len
    if incremental_state is not None:
        # Retrieve the target length from incremental states;
        # for inference the length of the query is always 1.
        max_tgt_len = incremental_state["steps"]["tgt"]
        assert max_tgt_len is not None
        max_tgt_len = int(max_tgt_len)
    else:
        max_tgt_len = tgt_len

    if max_src_len < waitk_lagging:
        # Fewer source tokens than the lagging k: nothing can be written yet.
        if incremental_state is not None:
            max_tgt_len = 1
        return torch.zeros(bsz, max_tgt_len, max_src_len)

    # Assuming p_choose looks like this for wait k=3,
    # src_len = 6, max_tgt_len = 4:
    #   [0, 0, 1, 0, 0, 0]
    #   [0, 0, 0, 1, 0, 0]
    #   [0, 0, 0, 0, 1, 0]
    #   [0, 0, 0, 0, 0, 1]
    # Linearize the p_choose matrix:
    #   [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...]
    # The indices of the linearized matrix that equal 1 are
    #   2 + 6 * 0
    #   3 + 6 * 1
    #   ...
    #   n + src_len * n + k - 1 = n * (src_len + 1) + k - 1
    # for n from 0 to max_tgt_len - 1.
    #
    # First, generate the indices (activate_indices_offset: bsz, max_tgt_len).
    # Second, scatter into a zeros tensor (bsz, max_tgt_len * src_len)
    # with activate_indices_offset.
    # Third, reshape the tensor to (bsz, max_tgt_len, src_len).
    activate_indices_offset = (
        (
            torch.arange(max_tgt_len) * (max_src_len + 1)
            + waitk_lagging - 1
        )
        .unsqueeze(0)
        .expand(bsz, max_tgt_len)
        .long()
    )

    if key_padding_mask is not None:
        if key_padding_mask[:, 0].any():
            # Left padding: shift the active indices past the padded positions.
            activate_indices_offset += key_padding_mask.sum(dim=1, keepdim=True)

    # Clamp the indices that are too large, i.e. target steps whose
    # wait-k position would fall beyond the last source token.
    activate_indices_offset = activate_indices_offset.clamp(
        0,
        min(max_tgt_len, max_src_len - waitk_lagging + 1) * max_src_len - 1,
    )

    p_choose = torch.zeros(bsz, max_tgt_len * max_src_len)
    p_choose = p_choose.scatter(
        1, activate_indices_offset, 1.0
    ).view(bsz, max_tgt_len, max_src_len)

    if key_padding_mask is not None:
        p_choose = p_choose.to(key_padding_mask)
        p_choose = p_choose.masked_fill(key_padding_mask.unsqueeze(1), 0)

    if incremental_state is not None:
        # During incremental inference, only the current (last) target step is needed.
        p_choose = p_choose[:, -1:]

    return p_choose.float()
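
# A minimal usage sketch (not part of the original file): it reproduces the
# wait-3 example worked through in the comment above. The `_demo_` name is
# illustrative only.
def _demo_waitk_p_choose():
    p_choose = waitk_p_choose(tgt_len=4, src_len=6, bsz=1, waitk_lagging=3)
    # Target step n places probability 1 on source position n + k - 1.
    expected = torch.tensor([
        [0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1.],
    ])
    assert torch.equal(p_choose[0], expected)
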
def learnable_p_choose(
    energy,
    noise_mean: float = 0.0,
    noise_var: float = 0.0,
    training: bool = True
):
    """
    Calculate the step-wise probabilities for reading and writing:
    1 means read, 0 means write.
    energy: bsz, tgt_len, src_len
    """
    noise = 0
    if training:
        # Add noise here to encourage discreteness; note that noise_var is
        # passed to torch.normal as the standard deviation.
        noise = (
            torch.normal(noise_mean, noise_var, energy.size())
            .type_as(energy)
            .to(energy.device)
        )

    p_choose = torch.sigmoid(energy + noise)

    # p_choose: bsz * num_heads, tgt_len, src_len
    return p_choose
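
# A minimal usage sketch (illustrative, not part of the original file):
# during training, Gaussian noise is added to the energies before the
# sigmoid to push the probabilities toward 0/1; at inference the sigmoid
# is applied to the raw energies.
def _demo_learnable_p_choose():
    energy = torch.randn(2, 4, 6)  # (bsz * num_heads, tgt_len, src_len)
    p_train = learnable_p_choose(energy, noise_mean=0.0, noise_var=1.0, training=True)
    p_infer = learnable_p_choose(energy, training=False)
    assert p_train.shape == p_infer.shape == energy.shape
    # Without noise, the result is exactly the sigmoid of the energies.
    assert torch.allclose(p_infer, torch.sigmoid(energy))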