pere committed
Commit 66d5aa3 · 1 Parent(s): 0be47bc

train code

Files changed (3):
  1. run.sh +36 -1
  2. run_speech_recognition_whisper_pere.py +81 -123
  3. run_xla.sh +4 -0
run.sh CHANGED
@@ -1,4 +1,36 @@
 
-python xla_spawn.py --num_cores=4 run_whisper.py
+python run_speech_recognition_whisper_pere.py \
+    --model_name_or_path="openai/whisper-small" \
+    --output_dir="../whisper-testrun1" \
+    --overwrite_output_dir=True \
+    --language="Norwegian" \
+    --task="transcribe" \
+    --dataset_name="mozilla-foundation/common_voice_11_0" \
+    --dataset_config="nn-NO" \
+    --do_train=True \
+    --do_eval=True \
+    --audio_column_name="audio" \
+    --text_column_name="sentence" \
+    --per_device_train_batch_size=16 \
+    --learning_rate=2e-5 \
+    --warmup_steps=500 \
+    --max_steps=5000 \
+    --gradient_checkpointing=True \
+    --gradient_accumulation_steps=1 \
+    --group_by_length=True \
+    --evaluation_strategy="steps" \
+    --save_steps=1000 \
+    --eval_steps=1000 \
+    --logging_steps=25 \
+    --fp16=True \
+    --load_best_model_at_end=True \
+    --metric_for_best_model="wer" \
+    --greater_is_better=False \
+    --report_to="tensorboard" \
+    --predict_with_generate=True \
+    --generation_max_length=225 \
+    --print_training_arguments=True \
+    --push_to_hub=True
 
 
+
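Note: the script consumes these flags through HfArgumentParser, which splits argv across the ModelArguments, DataTrainingArguments, and TrainingArguments dataclasses. A minimal sketch of that parsing pattern follows; the dataclass fields shown are an illustrative subset, not the script's full argument set.

# Minimal sketch of the argument-parsing pattern used by
# run_speech_recognition_whisper_pere.py. The fields below are an
# illustrative subset; the real script defines many more.
from dataclasses import dataclass, field
from transformers import HfArgumentParser, TrainingArguments

@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="openai/whisper-small")
    language: str = field(default="Norwegian")
    task: str = field(default="transcribe")

parser = HfArgumentParser((ModelArguments, TrainingArguments))
# Reads sys.argv, i.e. flags like those passed in run.sh above.
model_args, training_args = parser.parse_args_into_dataclasses()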
run_speech_recognition_whisper_pere.py CHANGED
@@ -22,27 +22,20 @@ import re
 import sys
 import warnings
 from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
+import evaluate
 
-import datasets
 import numpy as np
 import torch
+from pprint import pprint
 import evaluate
 from datasets import DatasetDict, load_dataset
+from datasets import Audio
 
-import transformers
 from transformers import (
-    AutoConfig,
-    AutoFeatureExtractor,
-    AutoModelForCTC,
-    AutoProcessor,
-    AutoTokenizer,
     HfArgumentParser,
-    Trainer,
     TrainingArguments,
-    Wav2Vec2Processor,
     set_seed,
-
     WhisperFeatureExtractor,
     WhisperTokenizer,
     WhisperForConditionalGeneration,
@@ -54,18 +47,9 @@ from transformers.trainer_utils import get_last_checkpoint, is_main_process
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 
-# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
-# check_min_version("4.24.0.dev0")
-
-# require_version("datasets>=2.6.1", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
-
-logger = logging.getLogger(__name__)
-
-
 def list_field(default=None, metadata=None):
     return field(default_factory=lambda: default, metadata=metadata)
 
-
 @dataclass
 class ModelArguments:
     """
@@ -243,6 +227,14 @@ class DataTrainingArguments:
         default="|",
         metadata={"help": "The word delimiter token for the tokenizer"},
     )
+    predict_with_generate: bool = field(
+        default=True,
+        metadata={"help": "Output tokens in addition to loss and digits for calculating metrics"},
+    )
+    generation_max_length: int = field(
+        default=225,
+        metadata={"help": "Maximum number of tokens generated"},
+    )
     phoneme_language: Optional[str] = field(
         default=None,
         metadata={
@@ -252,6 +244,12 @@ class DataTrainingArguments:
             " input audio to a sequence of phoneme sequences."
         },
     )
+    print_training_arguments: bool = field(
+        default=True,
+        metadata={
+            "help": "Prints the training arguments. For debugging"
+        },
+    )
 
 
 @dataclass
@@ -278,20 +276,17 @@ class DataCollatorSpeechSeq2SeqWithPadding:
             labels = labels[:, 1:]
 
         batch["labels"] = labels
-
         return batch
 
 
-
-
 def main():
     # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
 
     # Metrics
     def compute_metrics(pred):
         pred_ids = pred.predictions
@@ -309,8 +304,6 @@ def main():
         return {"wer": wer}
 
     # Prepare dataset
-
-
     def prepare_dataset(batch):
         # load and resample audio data from 48 to 16kHz
         audio = batch["audio"]
@@ -323,45 +316,58 @@ def main():
         batch["labels"] = tokenizer(batch["sentence"]).input_ids
         return batch
 
-    def make_dataset(training_args, data_args):
-        seed = training_args.seed or 42
-        dataset = datasets.load_dataset(training_args.dataset_name, training_args.dataset_config_name, use_auth_token=data_args.use_auth_token)
-        return dataset
+    def print_training_arguments(model_args, data_args, training_args):
+        print("Starting with the following parameters:")
+        print("\n* Model arguments:")
+        pprint(vars(model_args), indent=2)
+        print("\n* Data arguments")
+        pprint(vars(data_args), indent=2)
+        print("\n* Training arguments")
+        pprint(vars(training_args), indent=2)
+
+    # TODO - Might use this function later
+    # def make_dataset(training_args, data_args):
+    #     seed = training_args.seed or 42
+    #     dataset = datasets.load_dataset(training_args.dataset_name, training_args.dataset_config_name, use_auth_token=data_args.use_auth_token)
+    #     return dataset
+
+    # Print training arguments
+    if data_args.print_training_arguments:
+        print_training_arguments(model_args, data_args, training_args)
 
-    # PERE - SHOULD BE CHANGED TO STREAMING LATER
     # Load dataset
     speech_data = DatasetDict()
-
-    # The smallest dataset I found
     speech_data["train"] = load_dataset(
-        "mozilla-foundation/common_voice_11_0", "nn-NO", split="train", use_auth_token=True)
+        data_args.dataset_name, data_args.dataset_config_name, split="train", use_auth_token=True)
     speech_data["test"] = load_dataset(
-        "mozilla-foundation/common_voice_11_0", "nn-NO", split="test", use_auth_token=True)
+        data_args.dataset_name, data_args.dataset_config_name, split="test", use_auth_token=True)
 
-    # PERE - REPLACE WITH THIS
+    # TODO - Implement streaming and include this
     # speech_data = make_dataset(training_args, data_args)
 
-    # Rename columns
+
+    # Adapt dataset - Change column names and delete extra data
+    # Map columns
     if "audio" not in speech_data.column_names["train"]:
         speech_data = speech_data.rename_column(source, "audio")
 
     if "sentence" not in speech_data.column_names["train"]:
         speech_data = speech_data.rename_column(target, "sentence")
 
-    # Remove not needed columns - Not really sure if this is necessary
+    # Remove not needed columns
     remove_list = [i for i in speech_data.column_names["train"]
                    if i not in ["audio", "sentence"]]
 
     speech_data = speech_data.remove_columns(remove_list)
 
-    # PERE - NEEDS TO BE PARAMETERIZED
+
     # Initialise
     feature_extractor = WhisperFeatureExtractor.from_pretrained(
-        "openai/whisper-small")
+        model_args.model_name_or_path)
     tokenizer = WhisperTokenizer.from_pretrained(
-        "openai/whisper-small", language="Norwegian", task="transcribe")
+        model_args.model_name_or_path, language=model_args.language, task=model_args.task)
     processor = WhisperProcessor.from_pretrained(
-        "openai/whisper-small", language="Norwegian", task="transcribe")
+        model_args.model_name_or_path, language=model_args.language, task=model_args.task)
     data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
 
     # Prepare data
@@ -369,6 +375,8 @@ def main():
     speech_data = speech_data.map(
         prepare_dataset, remove_columns=speech_data.column_names["train"], num_proc=1)
 
+
+
     # Metrics
     metric = evaluate.load("wer")
 
@@ -395,88 +403,47 @@ def main():
     if last_checkpoint is not None:
         checkpoint = last_checkpoint
     elif os.path.isdir(model_args.model_name_or_path):
-        checkpoint = model_args.model_name_or_path
-        # Initialise a Pretrained model
-        # We need to set use_cache=False here if we want to use gradient accumulation
-        # PERE - For the test this is set static
-
-        model = WhisperForConditionalGeneration.from_pretrained(
-            "openai/whisper-small", use_cache=False)
-
+        checkpoint = model_args.model_name_or_path
     else:
         checkpoint = None
+
+    # We need to set use_cache=False here if we want to use gradient accumulation
+    model = WhisperForConditionalGeneration.from_pretrained(
+        "openai/whisper-small", use_cache=False)
 
+    # Overriding generation arguments - no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)):
+    model.config.forced_decoder_ids = None
+    model.config.suppress_tokens = []
+
+    # Set seed before initializing model.
+    set_seed(training_args.seed)
+
+
+    trainer = Seq2SeqTrainer(
+        args=training_args,
+        model=model,
+        train_dataset=speech_data["train"],
+        eval_dataset=speech_data["test"],
+        data_collator=data_collator,
+        compute_metrics=compute_metrics,
+        tokenizer=processor.feature_extractor,
+    )
+
     train_result = trainer.train(resume_from_checkpoint=checkpoint)
     trainer.save_model()
 
     metrics = train_result.metrics
-    max_train_samples = (
-        data_args.max_train_samples
-        if data_args.max_train_samples is not None
-        else len(vectorized_datasets["train"])
-    )
-    metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
-
     trainer.log_metrics("train", metrics)
     trainer.save_metrics("train", metrics)
     trainer.save_state()
-
-    # Overriding generation arguments - no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)):
-    model.config.forced_decoder_ids = None
-    model.config.suppress_tokens = []
-
-    # Set seed before initializing model.
-    set_seed(training_args.seed)
-
-    # Training arguments
-    training_args = Seq2SeqTrainingArguments(
-        output_dir="../whisper-testrun1",  # change to a repo name of your choice
-        per_device_train_batch_size=16,
-        gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
-        learning_rate=2e-5,
-        warmup_steps=500,
-        max_steps=5000,  # Changed from 4000
-        gradient_checkpointing=True,
-        group_by_length=True,
-        evaluation_strategy="steps",
-        per_device_eval_batch_size=8,
-        predict_with_generate=True,
-        generation_max_length=225,
-        save_steps=500,
-        eval_steps=500,
-        logging_steps=25,
-        report_to=["tensorboard"],
-        load_best_model_at_end=True,
-        metric_for_best_model="wer",
-        greater_is_better=False,
-        push_to_hub=True,
-    )
-
-    trainer = Seq2SeqTrainer(
-        args=training_args,
-        model=model,
-        train_dataset=speech_data["train"],
-        eval_dataset=speech_data["test"],
-        data_collator=data_collator,
-        compute_metrics=compute_metrics,
-        tokenizer=processor.feature_extractor,
-    )
-
-
-    # Initialize Trainer
-    trainer = Seq2SeqTrainer(
-        model=model,
-        data_collator=data_collator,
-        args=training_args,
-        compute_metrics=compute_metrics,
-        train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
-        eval_dataset=vectorized_datasets["validation"] if training_args.do_eval else None,
-        tokenizer=feature_extractor,
-    )
-
-    # 8. Finally, we can start training
-
+
+    if training_args.push_to_hub:
+        trainer.push_to_hub(**kwargs)
+    else:
+        trainer.create_model_card(**kwargs)
 
+    # TODO - Look closer into the evaluation and the model card writing.
+
     # Evaluation
     results = {}
     if training_args.do_eval:
@@ -500,14 +467,7 @@ def main():
         "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
         "language": model_args.language,
     }
-    if "common_voice" in data_args.dataset_name:
-        kwargs["language"] = config_name
-
-    if training_args.push_to_hub:
-        trainer.push_to_hub(**kwargs)
-    else:
-        trainer.create_model_card(**kwargs)
-
+
     return results
 
 
@@ -517,7 +477,5 @@ def _mp_fn(index):
     print("The XLA is initiated")
     main()
 
-
-
 if __name__ == "__main__":
     main()
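The body of DataCollatorSpeechSeq2SeqWithPadding is only partially visible in this diff. For orientation, below is a sketch of the standard Whisper padding collator that the visible lines (labels = labels[:, 1:]; batch["labels"] = labels) appear to follow; the exact body is an assumption based on the common recipe, not the committed code.

# Sketch of a speech seq2seq padding collator (assumed, based on the usual
# Whisper fine-tuning recipe; not necessarily identical to the commit).
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any  # a WhisperProcessor wrapping feature extractor + tokenizer

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Pad the log-Mel input features to a uniform length.
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Pad the tokenized labels, masking padding with -100 so the loss ignores it.
        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch["attention_mask"].ne(1), -100)

        # If the tokenizer prepended the BOS token, drop it; the model re-adds it itself.
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch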
run_xla.sh ADDED
@@ -0,0 +1,4 @@
+
+python xla_spawn.py --num_cores=4 run_whisper.py
+
+
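run_xla.sh relies on the xla_spawn.py helper from the Transformers examples, which imports the target script and runs its _mp_fn(index) hook once per TPU core (the hook is defined at the bottom of run_speech_recognition_whisper_pere.py). A rough sketch of that mechanism, assuming torch_xla is installed and a target module named run_whisper as in the script above:

# Rough sketch of what xla_spawn.py does (assumed mechanism, not the actual
# helper): spawn one process per TPU core and call the script's _mp_fn hook.
import torch_xla.distributed.xla_multiprocessing as xmp

import run_whisper  # hypothetical target module; must define _mp_fn(index)

if __name__ == "__main__":
    # nprocs mirrors the --num_cores=4 flag in run_xla.sh.
    xmp.spawn(run_whisper._mp_fn, args=(), nprocs=4)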