SWivid committed on
Commit
bc63315
·
1 Parent(s): 423fe4a

split pkgs only for eval usage, addressing #97; clean-up

Browse files
Files changed (5) hide show
  1. README.md +6 -0
  2. gradio_app.py +0 -2
  3. model/utils.py +6 -7
  4. requirements.txt +1 -8
  5. requirements_eval.txt +5 -0
README.md CHANGED
@@ -148,6 +148,12 @@ bash scripts/eval_infer_batch.sh
148
 
149
  ### Objective Evaluation
150
 
 
 
 
 
 
 
151
  **Some Notes**
152
 
153
  For faster-whisper with CUDA 11:
 
148
 
149
  ### Objective Evaluation
150
 
151
+ Install packages for evaluation:
152
+
153
+ ```bash
154
+ pip install -r requirements_eval.txt
155
+ ```
156
+
157
  **Some Notes**
158
 
159
  For faster-whisper with CUDA 11:
gradio_app.py CHANGED
@@ -1,4 +1,3 @@
1
- import os
2
  import re
3
  import torch
4
  import torchaudio
@@ -17,7 +16,6 @@ from model.utils import (
17
  save_spectrogram,
18
  )
19
  from transformers import pipeline
20
- import librosa
21
  import click
22
  import soundfile as sf
23
 
 
 
1
  import re
2
  import torch
3
  import torchaudio
 
16
  save_spectrogram,
17
  )
18
  from transformers import pipeline
 
19
  import click
20
  import soundfile as sf
21
 
model/utils.py CHANGED
@@ -22,12 +22,6 @@ from einops import rearrange, reduce
22
 
23
  import jieba
24
  from pypinyin import lazy_pinyin, Style
25
- import zhconv
26
- from zhon.hanzi import punctuation
27
- from jiwer import compute_measures
28
-
29
- from funasr import AutoModel
30
- from faster_whisper import WhisperModel
31
 
32
  from model.ecapa_tdnn import ECAPA_TDNN_SMALL
33
  from model.modules import MelSpec
@@ -432,6 +426,7 @@ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path
432
 
433
  def load_asr_model(lang, ckpt_dir = ""):
434
  if lang == "zh":
 
435
  model = AutoModel(
436
  model = os.path.join(ckpt_dir, "paraformer-zh"),
437
  # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
@@ -440,6 +435,7 @@ def load_asr_model(lang, ckpt_dir = ""):
440
  disable_update=True,
441
  ) # following seed-tts setting
442
  elif lang == "en":
 
443
  model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
444
  model = WhisperModel(model_size, device="cuda", compute_type="float16")
445
  return model
@@ -451,6 +447,7 @@ def run_asr_wer(args):
451
  rank, lang, test_set, ckpt_dir = args
452
 
453
  if lang == "zh":
 
454
  torch.cuda.set_device(rank)
455
  elif lang == "en":
456
  os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
@@ -458,10 +455,12 @@ def run_asr_wer(args):
458
  raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
459
 
460
  asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
461
-
 
462
  punctuation_all = punctuation + string.punctuation
463
  wers = []
464
 
 
465
  for gen_wav, prompt_wav, truth in tqdm(test_set):
466
  if lang == "zh":
467
  res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
 
22
 
23
  import jieba
24
  from pypinyin import lazy_pinyin, Style
 
 
 
 
 
 
25
 
26
  from model.ecapa_tdnn import ECAPA_TDNN_SMALL
27
  from model.modules import MelSpec
 
426
 
427
  def load_asr_model(lang, ckpt_dir = ""):
428
  if lang == "zh":
429
+ from funasr import AutoModel
430
  model = AutoModel(
431
  model = os.path.join(ckpt_dir, "paraformer-zh"),
432
  # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
 
435
  disable_update=True,
436
  ) # following seed-tts setting
437
  elif lang == "en":
438
+ from faster_whisper import WhisperModel
439
  model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
440
  model = WhisperModel(model_size, device="cuda", compute_type="float16")
441
  return model
 
447
  rank, lang, test_set, ckpt_dir = args
448
 
449
  if lang == "zh":
450
+ import zhconv
451
  torch.cuda.set_device(rank)
452
  elif lang == "en":
453
  os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
 
455
  raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
456
 
457
  asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
458
+
459
+ from zhon.hanzi import punctuation
460
  punctuation_all = punctuation + string.punctuation
461
  wers = []
462
 
463
+ from jiwer import compute_measures
464
  for gen_wav, prompt_wav, truth in tqdm(test_set):
465
  if lang == "zh":
466
  res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
requirements.txt CHANGED
@@ -5,11 +5,8 @@ datasets
5
  einops>=0.8.0
6
  einx>=0.3.0
7
  ema_pytorch>=0.5.2
8
- faster_whisper
9
- funasr
10
  gradio
11
  jieba
12
- jiwer
13
  librosa
14
  matplotlib
15
  numpy<=1.26.4
@@ -17,14 +14,10 @@ pydub
17
  pypinyin
18
  safetensors
19
  soundfile
20
- # torch>=2.0
21
- # torchaudio>=2.3.0
22
  torchdiffeq
23
  tqdm>=4.65.0
24
  transformers
25
  vocos
26
  wandb
27
  x_transformers>=1.31.14
28
- zhconv
29
- zhon
30
- tomli
 
5
  einops>=0.8.0
6
  einx>=0.3.0
7
  ema_pytorch>=0.5.2
 
 
8
  gradio
9
  jieba
 
10
  librosa
11
  matplotlib
12
  numpy<=1.26.4
 
14
  pypinyin
15
  safetensors
16
  soundfile
17
+ tomli
 
18
  torchdiffeq
19
  tqdm>=4.65.0
20
  transformers
21
  vocos
22
  wandb
23
  x_transformers>=1.31.14
 
 
 
requirements_eval.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ faster_whisper
2
+ funasr
3
+ jiwer
4
+ zhconv
5
+ zhon