KDM999 committed · verified
Commit 8cc7c73 · 1 Parent(s): a247854

Update app.py

Files changed (1)
  1. app.py +12 -37
app.py CHANGED
@@ -5,9 +5,7 @@ import os
 from difflib import SequenceMatcher
 from jiwer import wer
 import torchaudio
-import torch
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, HubertForCTC
-import whisper
+from transformers import pipeline
 
 # Load metadata
 with open("common_voice_en_validated_249_hf_ready.json") as f:
@@ -18,44 +16,21 @@ ages = sorted(set(entry["age"] for entry in data))
 genders = sorted(set(entry["gender"] for entry in data))
 accents = sorted(set(entry["accent"] for entry in data))
 
-# Load models
-device = "cuda" if torch.cuda.is_available() else "cpu"
+# Load pipelines
+device = 0  # 0 for CUDA/GPU, -1 for CPU
 
-# Whisper
-whisper_model = whisper.load_model("medium").to(device)
-
-# Wav2Vec2
-wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
-wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(device)
-
-# HuBERT
-hubert_processor = HubertProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
-hubert_model = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft").to(device)
+pipe_whisper = pipeline("automatic-speech-recognition", model="openai/whisper-medium", device=device)
+pipe_wav2vec2 = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h", device=device)
+pipe_hubert = pipeline("automatic-speech-recognition", model="facebook/hubert-base-ls960", device=device)
 
 def load_audio(file_path):
     waveform, sr = torchaudio.load(file_path)
     return torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0].numpy()
 
-def transcribe_whisper(file_path):
-    result = whisper_model.transcribe(file_path)
+def transcribe(pipe, file_path):
+    result = pipe(file_path)
     return result["text"].strip().lower()
 
-def transcribe_wav2vec(file_path):
-    audio = load_audio(file_path)
-    inputs = wav2vec_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
-    with torch.no_grad():
-        logits = wav2vec_model(**inputs.to(device)).logits
-    predicted_ids = torch.argmax(logits, dim=-1)
-    return wav2vec_processor.batch_decode(predicted_ids)[0].strip().lower()
-
-def transcribe_hubert(file_path):
-    audio = load_audio(file_path)
-    inputs = hubert_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
-    with torch.no_grad():
-        logits = hubert_model(**inputs.to(device)).logits
-    predicted_ids = torch.argmax(logits, dim=-1)
-    return hubert_processor.batch_decode(predicted_ids)[0].strip().lower()
-
 def highlight_differences(ref, hyp):
     sm = SequenceMatcher(None, ref.split(), hyp.split())
     result = []
@@ -79,9 +54,9 @@ def run_demo(age, gender, accent):
     file_path = os.path.join("common_voice_en_validated_249", sample["path"])
     gold = sample["sentence"].strip().lower()
 
-    whisper_text = transcribe_whisper(file_path)
-    wav2vec_text = transcribe_wav2vec(file_path)
-    hubert_text = transcribe_hubert(file_path)
+    whisper_text = transcribe(pipe_whisper, file_path)
+    wav2vec_text = transcribe(pipe_wav2vec2, file_path)
+    hubert_text = transcribe(pipe_hubert, file_path)
 
     table = f"""
     <table border="1" style="width:100%">
@@ -118,4 +93,4 @@ with gr.Blocks() as demo:
         gr.Textbox(label="Path")
     ])
 
-demo.launch()
+demo.launch()
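
Quick orientation on the new approach, separate from the commit itself: the rewritten app.py swaps the hand-rolled model/processor loading for the transformers automatic-speech-recognition pipeline, which decodes the audio file and returns a dict with a "text" key. A minimal, self-contained sketch of that usage pattern follows; the CPU fallback and the sample file name are assumptions for illustration, while the model id and the transcribe helper mirror the new code.

import torch
from transformers import pipeline

# Same device convention as the committed code: 0 = first CUDA GPU, -1 = CPU.
# (assumption: adding a CPU fallback; the commit hard-codes device = 0)
device = 0 if torch.cuda.is_available() else -1

# One pipeline per model; openai/whisper-medium matches the commit.
pipe_whisper = pipeline("automatic-speech-recognition", model="openai/whisper-medium", device=device)

def transcribe(pipe, file_path):
    # The ASR pipeline handles loading, resampling, and decoding,
    # and returns a dict with a "text" key.
    result = pipe(file_path)
    return result["text"].strip().lower()

print(transcribe(pipe_whisper, "sample.mp3"))  # "sample.mp3" is a placeholder path

Moving decoding into the pipeline is what lets the commit drop the per-model processors, the explicit torch.no_grad() forward pass, and the argmax/batch_decode steps, which accounts for the +12/-37 line change.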