Update run_distillation.py
Browse files- run_distillation.py +28 -12
run_distillation.py
CHANGED
@@ -1141,12 +1141,14 @@ def main():
|
|
1141 |
if whisper_transcript is not None and whisper_transcript.upper() == whisper_transcript:
|
1142 |
# filter entirely upper-case transcriptions: these are erroneous generations from large-v3
|
1143 |
return False
|
1144 |
-
elif len(norm_ground_truth)
|
|
|
|
|
1145 |
norm_whisper_transcript = normalizer(whisper_transcript)
|
1146 |
wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
|
1147 |
return wer < wer_threshold
|
1148 |
else:
|
1149 |
-
# filter automatically since we can't know the WER
|
1150 |
return False
|
1151 |
|
1152 |
filter_by_wer_threshold = partial(
|
@@ -1327,16 +1329,30 @@ def main():
|
|
1327 |
label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
1328 |
wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)
|
1329 |
|
1330 |
-
#
|
1331 |
-
norm_pred_str = [
|
1332 |
-
norm_label_str = [
|
1333 |
-
# for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
|
1334 |
-
pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
|
1335 |
-
label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
|
1336 |
-
# filtering step to only evaluate the samples that correspond to non-zero normalized references:
|
1337 |
-
norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
|
1338 |
-
norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
|
1339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1340 |
wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
|
1341 |
return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
|
1342 |
|
@@ -1808,4 +1824,4 @@ def main():
|
|
1808 |
|
1809 |
|
1810 |
if __name__ == "__main__":
|
1811 |
-
main()
|
|
|
1141 |
if whisper_transcript is not None and whisper_transcript.upper() == whisper_transcript:
|
1142 |
# filter entirely upper-case transcriptions: these are erroneous generations from large-v3
|
1143 |
return False
|
1144 |
+
elif len(norm_ground_truth) == 0 and len(normalizer(whisper_transcript)) == 0:
|
1145 |
+
return True
|
1146 |
+
elif len(norm_ground_truth.strip()) > 0 and whisper_transcript is not None and len(normalizer(whisper_transcript).strip()) > 0:
|
1147 |
norm_whisper_transcript = normalizer(whisper_transcript)
|
1148 |
wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
|
1149 |
return wer < wer_threshold
|
1150 |
else:
|
1151 |
+
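# filter automatically since the WER cannot be computed (only one of the reference/transcript is empty, or the transcript is missing)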
|
1152 |
return False
|
1153 |
|
1154 |
filter_by_wer_threshold = partial(
|
|
|
1329 |
label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
1330 |
wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)
|
1331 |
|
1332 |
+
# Normalize everything
|
1333 |
+
norm_pred_str = []
|
1334 |
+
norm_label_str = []
|
|
|
|
|
|
|
|
|
|
|
|
|
1335 |
|
1336 |
+
# Iterate through all predictions and labels
|
1337 |
+
for pred, label in zip(pred_str, label_str):
|
1338 |
+
# Normalize the prediction and label
|
1339 |
+
normalized_pred = normalizer(pred)
|
1340 |
+
normalized_label = normalizer(label)
|
1341 |
+
|
1342 |
+
# If either string is empty after normalization, replace it with "<|nocaptions|>"
|
1343 |
+
if not normalized_pred.strip():
|
1344 |
+
normalized_pred = "<|nocaptions|>"
|
1345 |
+
if not normalized_label.strip():
|
1346 |
+
normalized_label = "<|nocaptions|>"
|
1347 |
+
|
1348 |
+
norm_pred_str.append(normalized_pred)
|
1349 |
+
norm_label_str.append(normalized_label)
|
1350 |
+
|
1351 |
+
# Likewise replace empty or whitespace-only original (un-normalized) strings with "<|nocaptions|>"
|
1352 |
+
pred_str = [pred if len(pred.strip()) > 0 else "<|nocaptions|>" for pred in pred_str]
|
1353 |
+
label_str = [label if len(label.strip()) > 0 else "<|nocaptions|>" for label in label_str]
|
1354 |
+
|
1355 |
+
# Compute WER using all entries, including those with "<|nocaptions|>"
|
1356 |
wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
|
1357 |
return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
|
1358 |
|
|
|
1824 |
|
1825 |
|
1826 |
if __name__ == "__main__":
|
1827 |
+
main()
|