import torch
from datasets import load_dataset, DatasetDict
from datasets import Audio

from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
from transformers import WhisperProcessor
from transformers import WhisperForConditionalGeneration

from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer

from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate


# Functions
# Define a Data Collator
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]}
                          for feature in features]
        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]}
                          for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(
            label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100)

        # if a bos token was appended in the previous tokenization step,
        # cut it here as it is appended again during training anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch
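
# A collated batch is roughly (a sketch, assuming whisper-small defaults:
# 80 mel bins, 30s windows at 16kHz -> 3000 frames):
#   batch["input_features"]: FloatTensor of shape (batch_size, 80, 3000)
#   batch["labels"]:         LongTensor of shape (batch_size, max_label_len),
#                            with padding positions set to -100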

def main():
    # Metrics
    def compute_metrics(pred):
        pred_ids = pred.predictions
        label_ids = pred.label_ids

        # replace -100 with the pad_token_id
        label_ids[label_ids == -100] = tokenizer.pad_token_id

        # we do not want to group tokens when computing the metrics
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        wer = 100 * metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

    # Prepare dataset
    def prepare_dataset(batch):
        # load the audio; decoding resamples from the source rate
        # (48kHz for Common Voice) to 16kHz via the cast_column call below
        audio = batch["audio"]

        # compute log-Mel input features from input audio array
        batch["input_features"] = feature_extractor(
            audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

        # encode target text to label ids
        batch["labels"] = tokenizer(batch["sentence"]).input_ids
        return batch
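
    # After prepare_dataset, an example looks roughly like (a sketch, token
    # ids not exact): {"input_features": <80 x 3000 log-Mel array>,
    #                  "labels": [<sot/language/task ids>, ..., <eot id>]}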


    # Whisper Training Script

    # Map the source and target columns.
    # Whisper expects these to be "audio" and "sentence"; change them if your
    # dataset uses different column names.
    source = "audio"
    target = "sentence"


    # Load a sample dataset
    speech_data = DatasetDict()

    # Examples
    # speech_data["train"] = load_dataset("NbAiLab/NPSC", "16K_mp3_bokmaal", split="train", use_auth_token=True)
    # speech_data["test"] = load_dataset("NbAiLab/NPSC", "16K_mp3_bokmaal", split="test", use_auth_token=True)
    # speech_data["train"] = load_dataset("NbAiLab/LIA_speech", split="train", use_auth_token=True)
    # speech_data["test"] = load_dataset("NbAiLab/LIA_speech", split="test", use_auth_token=True)

    # The smallest dataset I found
    speech_data["train"] = load_dataset(
        "mozilla-foundation/common_voice_11_0", "nn-NO", split="train", use_auth_token=True)
    speech_data["test"] = load_dataset(
        "mozilla-foundation/common_voice_11_0", "nn-NO", split="test", use_auth_token=True)


    # Rename columns
    if "audio" not in speech_data.column_names["train"]:
        speech_data = speech_data.rename_column(source, "audio")

    if "sentence" not in speech_data.column_names["train"]:
        speech_data = speech_data.rename_column(target, "sentence")

    # Remove unneeded columns (strictly speaking optional: the map() call
    # below removes all original columns anyway)
    remove_list = [i for i in speech_data.column_names["train"]
                   if i not in ["audio", "sentence"]]

    speech_data = speech_data.remove_columns(remove_list)
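
    # Illustrative sanity check: after the removal, each split should expose
    # exactly the two columns the rest of the pipeline expects
    assert set(speech_data.column_names["train"]) == {"audio", "sentence"}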

    # Initialise
    feature_extractor = WhisperFeatureExtractor.from_pretrained(
        "openai/whisper-small")
    tokenizer = WhisperTokenizer.from_pretrained(
        "openai/whisper-small", language="Norwegian", task="transcribe")
    processor = WhisperProcessor.from_pretrained(
        "openai/whisper-small", language="Norwegian", task="transcribe")
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

    # Prepare data
    speech_data = speech_data.cast_column("audio", Audio(sampling_rate=16000))
    speech_data = speech_data.map(
        prepare_dataset, remove_columns=speech_data.column_names["train"], num_proc=1)

    # Metrics
    metric = evaluate.load("wer")

    # Initialise a pretrained model
    # use_cache=False is required because gradient checkpointing is enabled below
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-small", use_cache=False)

    # Overriding generation arguments - no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)):
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []

    # Training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir="../whisper-testrun1",  # change to a repo name of your choice
        per_device_train_batch_size=16,
        gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
        learning_rate=2e-5,
        warmup_steps=500,
        max_steps=5000,  # Changed from 4000
        gradient_checkpointing=True,
        group_by_length=True,
        evaluation_strategy="steps",
        per_device_eval_batch_size=8,
        predict_with_generate=True,
        generation_max_length=225,
        save_steps=500,
        eval_steps=500,
        logging_steps=25,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
        push_to_hub=True,
    )
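
    # Note: the effective batch size is per_device_train_batch_size x
    # gradient_accumulation_steps x number of devices (16 per device here)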

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=speech_data["train"],
        eval_dataset=speech_data["test"],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        tokenizer=processor.feature_extractor,
    )


    # Start training
    trainer.train()


def _mp_fn(index):
    # For xla_spawn (TPUs)
    print("The XLA is initiated")
    main()


if __name__ == "__main__":
    main()
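

# Example launch commands (illustrative; the filename is hypothetical):
#   python finetune_whisper.py
#   python xla_spawn.py --num_cores 8 finetune_whisper.py  # TPUs, via _mp_fn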