pere commited on
Commit
7f3bd04
·
verified ·
1 Parent(s): d4cf8c0

Update run_distillation.py

Browse files
Files changed (1) hide show
  1. run_distillation.py +28 -12
run_distillation.py CHANGED
@@ -1141,12 +1141,14 @@ def main():
1141
  if whisper_transcript is not None and whisper_transcript.upper() == whisper_transcript:
1142
  # filter entirely upper-case transcriptions: these are erroneous generations from large-v3
1143
  return False
1144
- elif len(norm_ground_truth) > 0 and whisper_transcript is not None:
 
 
1145
  norm_whisper_transcript = normalizer(whisper_transcript)
1146
  wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
1147
  return wer < wer_threshold
1148
  else:
1149
- # filter automatically since we can't know the WER
1150
  return False
1151
 
1152
  filter_by_wer_threshold = partial(
@@ -1327,16 +1329,30 @@ def main():
1327
  label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
1328
  wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)
1329
 
1330
- # normalize everything and re-compute the WER
1331
- norm_pred_str = [normalizer(pred) for pred in pred_str]
1332
- norm_label_str = [normalizer(label) for label in label_str]
1333
- # for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
1334
- pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1335
- label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1336
- # filtering step to only evaluate the samples that correspond to non-zero normalized references:
1337
- norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1338
- norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1340
  wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
1341
  return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
1342
 
@@ -1808,4 +1824,4 @@ def main():
1808
 
1809
 
1810
  if __name__ == "__main__":
1811
- main()
 
1141
  if whisper_transcript is not None and whisper_transcript.upper() == whisper_transcript:
1142
  # filter entirely upper-case transcriptions: these are erroneous generations from large-v3
1143
  return False
1144
+ elif len(norm_ground_truth) == 0 and len(normalizer(whisper_transcript)) == 0:
1145
+ return True
1146
+ elif len(norm_ground_truth.strip()) > 0 and whisper_transcript is not None and len(normalizer(whisper_transcript).strip()) > 0:
1147
  norm_whisper_transcript = normalizer(whisper_transcript)
1148
  wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
1149
  return wer < wer_threshold
1150
  else:
1151
+ # filter automatically since weR
1152
  return False
1153
 
1154
  filter_by_wer_threshold = partial(
 
1329
  label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
1330
  wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)
1331
 
1332
+ # Normalize everything
1333
+ norm_pred_str = []
1334
+ norm_label_str = []
 
 
 
 
 
 
1335
 
1336
+ # Iterate through all predictions and labels
1337
+ for pred, label in zip(pred_str, label_str):
1338
+ # Normalize the prediction and label
1339
+ normalized_pred = normalizer(pred)
1340
+ normalized_label = normalizer(label)
1341
+
1342
+ # If either normalized string is empty after normalization, replace with "<|nocaptions|>"
1343
+ if not normalized_pred.strip():
1344
+ normalized_pred = "<|nocaptions|>"
1345
+ if not normalized_label.strip():
1346
+ normalized_label = "<|nocaptions|>"
1347
+
1348
+ norm_pred_str.append(normalized_pred)
1349
+ norm_label_str.append(normalized_label)
1350
+
1351
+ # Replace original strings with "<|nocaptions|>" where necessary for consistency
1352
+ pred_str = [pred if len(pred.strip()) > 0 else "<|nocaptions|>" for pred in pred_str]
1353
+ label_str = [label if len(label.strip()) > 0 else "<|nocaptions|>" for label in label_str]
1354
+
1355
+ # Compute WER using all entries, including those with "<|nocaptions|>"
1356
  wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
1357
  return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
1358
 
 
1824
 
1825
 
1826
  if __name__ == "__main__":
1827
+ main()