Create viterbi_decoding.py

viterbi_decoding.py ADDED (+137 -0)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

V_NEGATIVE_NUM = -3.4e38


def viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, viterbi_device):
    """
    Do Viterbi decoding with an efficient algorithm (the only for-loop in the 'forward pass' is over the time dimension).
    Args:
        log_probs_batch: tensor of shape (B, T_max, V). The parts of log_probs_batch which are 'padding' are filled
            with 'V_NEGATIVE_NUM' - a large negative number which represents a very low probability.
        y_batch: tensor of shape (B, U_max) - contains token IDs including blanks in every other position. The parts of
            y_batch which are padding are filled with the number 'V'. V = the number of tokens in the vocabulary + 1 for
            the blank token.
        T_batch: tensor of shape (B, 1) - contains the durations of the log_probs_batch (so we can ignore the
            parts of log_probs_batch which are padding).
        U_batch: tensor of shape (B, 1) - contains the lengths of y_batch (so we can ignore the parts of y_batch
            which are padding).
        viterbi_device: the torch device on which Viterbi decoding will be done.

    Returns:
        alignments_batch: list of lists containing locations for the tokens we align to at each timestep.
            Looks like: [[0, 0, 1, 2, 2, 3, 3, ...], ..., [0, 1, 2, 2, 2, 3, 4, ...]].
            Each list inside alignments_batch is of length T_batch[location of utt in batch].
    """
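
    # Example with hypothetical token IDs: for reference text "cat" with 'c'=2, 'a'=0, 't'=19 and the
    # blank as the final vocabulary entry, y for that utterance is [blank, 2, blank, 0, blank, 19, blank]
    # - tokens interleaved with blanks as described above - and each entry of the returned alignment is
    # an index into this blank-interleaved sequence.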

    B, T_max, _ = log_probs_batch.shape
    U_max = y_batch.shape[1]

    # transfer all tensors to viterbi_device
    log_probs_batch = log_probs_batch.to(viterbi_device)
    y_batch = y_batch.to(viterbi_device)
    T_batch = T_batch.to(viterbi_device)
    U_batch = U_batch.to(viterbi_device)

    # make tensor that we will put at timesteps beyond the duration of the audio
    padding_for_log_probs = V_NEGATIVE_NUM * torch.ones((B, T_max, 1), device=viterbi_device)
    # make log_probs_padded tensor of shape (B, T_max, V + 1) where all of
    # log_probs_padded[:, :, -1] is 'V_NEGATIVE_NUM'
    log_probs_padded = torch.cat((log_probs_batch, padding_for_log_probs), dim=2)

    # initialize v_prev - tensor of the previous timestep's viterbi probabilities, of shape (B, U_max)
    v_prev = V_NEGATIVE_NUM * torch.ones((B, U_max), device=viterbi_device)
    v_prev[:, :2] = torch.gather(input=log_probs_padded[:, 0, :], dim=1, index=y_batch[:, :2])
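    # (at t=0 a valid alignment can only occupy u=0, the initial blank, or u=1, the first real token,
    # which is why only the first two entries of v_prev receive log-probs here; every other position
    # stays at V_NEGATIVE_NUM)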

    # initialize backpointers_rel - which contains values like 0 to indicate the backpointer is to the same u index,
    # 1 to indicate the backpointer is pointing to the u-1 index, and 2 to indicate the backpointer is pointing to the u-2 index
    backpointers_rel = -99 * torch.ones((B, T_max, U_max), dtype=torch.int8, device=viterbi_device)

    # Make a letter_repetition_mask the same shape as y_batch.
    # The letter_repetition_mask will have 'True' where the token (including blanks) is the same
    # as the token two places before it in the ground truth (and 'False' everywhere else).
    # We will use letter_repetition_mask to determine whether the Viterbi algorithm needs to look two tokens back or
    # three tokens back.
    y_shifted_left = torch.roll(y_batch, shifts=2, dims=1)
    letter_repetition_mask = y_batch - y_shifted_left
    letter_repetition_mask[:, :2] = 1  # make sure we don't apply the mask to the first 2 tokens
    letter_repetition_mask = letter_repetition_mask == 0
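
    # For example (hypothetical IDs): y = [blank, 5, blank, 5, blank] - a repeated letter - gives
    # letter_repetition_mask = [False, False, True, True, True], so the u-2 transition is blocked wherever
    # taking it would skip the blank separating a repeated letter, or skip a letter between two blanks.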

    for t in range(1, T_max):

        # e_current is a tensor of shape (B, U_max) of the log probs of every possible token at the current timestep
        e_current = torch.gather(input=log_probs_padded[:, t, :], dim=1, index=y_batch)

        # apply a mask to e_current to cope with the fact that we do not keep the whole v_matrix and continue
        # calculating viterbi probabilities during some 'padding' timesteps
        t_exceeded_T_batch = t >= T_batch

        U_can_be_final = torch.logical_or(
            torch.arange(0, U_max, device=viterbi_device).unsqueeze(0) == (U_batch.unsqueeze(1) - 0),
            torch.arange(0, U_max, device=viterbi_device).unsqueeze(0) == (U_batch.unsqueeze(1) - 1),
        )

        mask = torch.logical_not(torch.logical_and(t_exceeded_T_batch.unsqueeze(1), U_can_be_final,)).long()

        e_current = e_current * mask
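
        # i.e. once t >= T_batch for an utterance, the emission score is zeroed at the token positions that
        # are allowed to end the alignment, so their running viterbi scores are carried through the padded
        # frames unchanged instead of being dragged down by the V_NEGATIVE_NUM padding log-probs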

        # v_prev_shifted is a tensor of shape (B, U_max) of the viterbi probabilities 1 timestep back and 1 token position back
        v_prev_shifted = torch.roll(v_prev, shifts=1, dims=1)
        # by doing a roll shift of size 1, we have brought the viterbi probability in the final token position to the
        # first token position - let's overcome this by 'zeroing out' the probabilities in the first token position
        v_prev_shifted[:, 0] = V_NEGATIVE_NUM

        # v_prev_shifted2 is a tensor of shape (B, U_max) of the viterbi probabilities 1 timestep back and 2 token positions back
        v_prev_shifted2 = torch.roll(v_prev, shifts=2, dims=1)
        v_prev_shifted2[:, :2] = V_NEGATIVE_NUM  # zero out as we did for v_prev_shifted
        # use our letter_repetition_mask to remove the connections between 2 blanks (so we don't skip over a letter)
        # and to remove the connections between 2 consecutive letters (so we don't skip over a blank)
        v_prev_shifted2.masked_fill_(letter_repetition_mask, V_NEGATIVE_NUM)

        # we need this v_prev_dup tensor so we can calculate the viterbi probability of every possible
        # token position simultaneously
        v_prev_dup = torch.cat(
            (v_prev.unsqueeze(2), v_prev_shifted.unsqueeze(2), v_prev_shifted2.unsqueeze(2),), dim=2,
        )

        # candidates_v_current are our candidate viterbi probabilities for every token position, from which
        # we will pick the max and record the argmax
        candidates_v_current = v_prev_dup + e_current.unsqueeze(2)
        # we straight away save the results in v_prev instead of v_current, so that the variable v_prev will be
        # ready for the next iteration of the for-loop
        v_prev, bp_relative = torch.max(candidates_v_current, dim=2)

        backpointers_rel[:, t, :] = bp_relative
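
    # After the loop, v_prev holds the viterbi scores at the last timestep, and for every token position u the
    # loop above has implemented the recurrence
    #     v[t, u] = log_probs[t, y[u]] + max(v[t-1, u], v[t-1, u-1], v[t-1, u-2]),
    # where the u-2 option is disallowed by letter_repetition_mask wherever taking it would skip a required
    # blank or letter; bp_relative recorded which of the three options won at each step.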

    # trace backpointers
    alignments_batch = []
    for b in range(B):
        T_b = int(T_batch[b])
        U_b = int(U_batch[b])

        if U_b == 1:  # i.e. we put only a blank token in the reference text because the reference text is empty
            current_u = 0  # set initial u to 0 and let the rest of the code block run as usual
        else:
            current_u = int(torch.argmax(v_prev[b, U_b - 2 : U_b])) + U_b - 2
        alignment_b = [current_u]
        for t in range(T_max - 1, 0, -1):
            current_u = current_u - int(backpointers_rel[b, t, current_u])
            alignment_b.insert(0, current_u)
        alignment_b = alignment_b[:T_b]
        alignments_batch.append(alignment_b)

    return alignments_batch
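
A minimal usage sketch, assuming the file above is saved as viterbi_decoding.py; the vocabulary size, token IDs and random log-probs below are made up purely for illustration.

import torch

from viterbi_decoding import V_NEGATIVE_NUM, viterbi_decoding

# hypothetical sizes: 2 utterances, 8 frames, 29 output classes (including the blank token)
B, T_max, V = 2, 8, 29
blank_id = V - 1  # assumed blank position for this sketch

# random log-probs standing in for CTC model output
log_probs_batch = torch.log_softmax(torch.randn(B, T_max, V), dim=-1)
# the second utterance is only 6 frames long, so its padding frames are filled with V_NEGATIVE_NUM
log_probs_batch[1, 6:, :] = V_NEGATIVE_NUM

# blank-interleaved token IDs, padded with the value V (here 29) as described in the docstring
y_batch = torch.tensor(
    [
        [blank_id, 2, blank_id, 0, blank_id, 19, blank_id],  # e.g. "cat"
        [blank_id, 7, blank_id, 7, blank_id, V, V],  # a repeated letter, then padding
    ]
)
# per-utterance frame counts and token-sequence lengths
T_batch = torch.tensor([8, 6])
U_batch = torch.tensor([7, 5])

alignments = viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, torch.device("cpu"))
# alignments[b][t] is the index into y_batch[b] that frame t is aligned to; len(alignments[b]) == T_batch[b]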