mskov committed on
Commit
176ad20
·
1 Parent(s): 182346b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -26
app.py CHANGED
@@ -22,32 +22,29 @@ dataset = load_dataset("mskov/miso_test", split="test").cast_column("audio", Aud
22
 
23
  print(dataset, "and at 0[audio][array] ", dataset[0]["audio"]["array"], type(dataset[0]["audio"]["array"]), "and at audio : ", dataset[0]["audio"])
24
 
25
- test = evalWhisper(model, dataset)
26
- print("test ", test)
27
-
28
- def evalWhisper(model, dataset):
29
- model.eval()
30
- print("model.eval ", model.eval())
31
-
32
- # Evaluate the model
33
- model.eval()
34
- print("model.eval ", model.eval())
35
- with torch.no_grad():
36
- outputs = model(input_ids=input_ids, attention_mask=attention_mask)
37
- print("outputs ", outputs)
38
-
39
- # Convert predicted token IDs back to text
40
- predicted_text = tokenizer.batch_decode(outputs.logits.argmax(dim=-1), skip_special_tokens=True)
41
-
42
- # Get ground truth labels from the dataset
43
- labels = dataset["audio"] # Replace "labels" with the appropriate key in your dataset
44
- print("labels are ", labels)
45
-
46
- # Compute WER
47
- wer_score = wer(labels, predicted_text)
48
-
49
- # Print or return WER score
50
- print(f"Word Error Rate (WER): {wer_score}")
51
 
52
 
53
  def transcribe(audio):
 
22
 
23
  print(dataset, "and at 0[audio][array] ", dataset[0]["audio"]["array"], type(dataset[0]["audio"]["array"]), "and at audio : ", dataset[0]["audio"])
24
 
25
+
26
+ model.eval()
27
+ print("model.eval ", model.eval())
28
+
29
+ # Evaluate the model
30
+ model.eval()
31
+ print("model.eval ", model.eval())
32
+ with torch.no_grad():
33
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
34
+ print("outputs ", outputs)
35
+
36
+ # Convert predicted token IDs back to text
37
+ predicted_text = tokenizer.batch_decode(outputs.logits.argmax(dim=-1), skip_special_tokens=True)
38
+
39
+ # Get ground truth labels from the dataset
40
+ labels = dataset["audio"] # Replace "labels" with the appropriate key in your dataset
41
+ print("labels are ", labels)
42
+
43
+ # Compute WER
44
+ wer_score = wer(labels, predicted_text)
45
+
46
+ # Print or return WER score
47
+ print(f"Word Error Rate (WER): {wer_score}")
 
 
 
48
 
49
 
50
  def transcribe(audio):