import os
import sys

# Install pinned dependencies at startup (a common Hugging Face Spaces workaround)
os.system("pip install transformers==4.27.0")
os.system("pip install numpy==1.23")
os.system("pip install jiwer")
os.system("pip install datasets[audio]")

import torch
from transformers import pipeline, WhisperModel, WhisperTokenizer, WhisperFeatureExtractor, AutoFeatureExtractor, AutoProcessor, WhisperConfig
from jiwer import wer
import evaluate
from evaluate import evaluator
from datasets import load_dataset, Audio, disable_caching
import gradio as gr

# disable_caching() already covers what the deprecated set_caching_enabled(False) did
disable_caching()

# Token is read from the Space secrets; .get() avoids a KeyError when it is unset
huggingface_token = os.environ.get("huggingface_token")

pipe = pipeline(model="mskov/whisper-small-esc50")
print(pipe)

# Load the test split and resample every clip to the 16 kHz rate Whisper expects
dataset = load_dataset("mskov/miso_test", split="test").cast_column("audio", Audio(sampling_rate=16000))
print(dataset)
print("first example:", dataset[0]["audio"])
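
# Optional sanity check (a sketch, not part of the app flow): score the
# fine-tuned pipeline over the whole split with the `evaluate` WER metric
# imported above. The "text" column name is an assumption; substitute the
# actual transcription column of mskov/miso_test.
# wer_metric = evaluate.load("wer")
# predictions = [pipe(ex["array"])["text"] for ex in dataset["audio"]]
# print(wer_metric.compute(predictions=predictions, references=dataset["text"]))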


def transcribe(audio):
    text = pipe(audio)["text"]
    return text
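
# The ASR pipeline also accepts raw numpy arrays at 16 kHz, so the same helper
# can be smoke-tested against the loaded split before wiring up the UI, e.g.:
# print(transcribe(dataset[0]["audio"]["array"]))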

iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(source="microphone", type="filepath"),
    outputs="text",
    title="Whisper Small Miso Test",
)
# Launching is deferred: iface.launch() blocks, so the app is launched once at
# the end of the script together with the evaluation log view below.

def evalWhisper(model, dataset):
    model.eval()

    # Collect log messages so they can be shown in the UI as well as printed
    log_texts = []

    # Build model inputs from the first audio example, reusing the pipeline's
    # feature extractor and tokenizer
    feature_extractor = pipe.feature_extractor
    tokenizer = pipe.tokenizer
    sample = dataset[0]["audio"]
    input_features = feature_extractor(
        sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
    ).input_features

    # Whisper is an encoder-decoder model, so transcribe with generate()
    # rather than a bare forward pass (which would need decoder_input_ids)
    with torch.no_grad():
        predicted_ids = model.generate(input_features)
    log_texts.append(f"predicted ids: {predicted_ids}")

    # Convert predicted token IDs back to text
    predicted_text = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)

    # Ground-truth transcription; "text" is an assumed column name, replace it
    # with the transcription column of your dataset
    labels = [dataset[0]["text"]]
    print("labels are ", labels)
    log_texts.append(f"labels: {labels}")

    # Compute word error rate between references and predictions
    wer_score = wer(labels, predicted_text)
    wer_message = f"Word Error Rate (WER): {wer_score}"
    print(wer_message)
    log_texts.append(wer_message)

    return log_texts

# Run the evaluation once at startup; pipe.model is the fine-tuned
# WhisperForConditionalGeneration behind the pipeline
log_texts = evalWhisper(pipe.model, dataset)

# Display the collected log texts in a second tab alongside the transcriber
log_text = "\n".join(log_texts)
log_interface = gr.Interface(
    fn=lambda: log_text,
    inputs=None,
    outputs="text",
    title="EvalWhisper Log",
)

gr.TabbedInterface(
    [iface, log_interface], ["Transcribe", "EvalWhisper Log"]
).launch()


'''
Scratch snippet kept from development: inspect the decoder hidden-state shape
for a prepared `inputs` batch (requires a bare WhisperModel).
print("check check")
print(inputs)
input_features = inputs.input_features
decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
print(list(last_hidden_state.shape))
'''