etchen commited on
Commit
912ccda
·
verified ·
1 Parent(s): 24ef1e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -7,6 +7,9 @@ import torch
7
 
8
  from transformers import pipeline
9
 
 
 
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  model_repo_id = "emlinking/wav2vec2-large-xls-r-300m-tsm-asr-v6"
12
 
@@ -15,23 +18,25 @@ if torch.cuda.is_available():
15
  else:
16
  torch_dtype = torch.float32
17
 
18
- pipe = pipeline(task="automatic-speech-recognition", model=model_repo_id, device=device)
 
19
 
20
  # @spaces.GPU #[uncomment to use ZeroGPU]
21
  def infer(
22
  audio,
23
  target
24
  ):
 
25
  sampling_rate, wav = audio
26
  if wav.ndim > 1:
27
  wav = wav.mean(axis=1)
28
  wav = wav.astype(np.float32)
29
  wav /= np.max(np.abs(wav))
30
- user_pron = pipe(wav)['text']
 
31
 
32
  # compare texts
33
- d = Differ()
34
- d_toks = [(i[2:], i[0] if i[0] != " " else None) for i in d.compare(target, user_pron)]
35
  return (user_pron, d_toks)
36
 
37
  css = """
@@ -52,7 +57,7 @@ with gr.Blocks(css=css) as demo:
52
  label='Comparison',
53
  combine_adjacent=True,
54
  show_legend=True,
55
- color_map={'+': 'red', '-': 'green'}
56
  )
57
  input_audio.input(fn=infer, inputs=[input_audio, target], outputs=[output, diff])
58
 
 
7
 
8
  from transformers import pipeline
9
 
10
+ # ################ CHANGE THIS TO CHANGE THE LANGUAGE ###################### #
11
+ from TaiwaneseHokkien import TaiwaneseHokkien
12
+
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  model_repo_id = "emlinking/wav2vec2-large-xls-r-300m-tsm-asr-v6"
15
 
 
18
  else:
19
  torch_dtype = torch.float32
20
 
21
+ language = TaiwaneseHokkien(device=device, torch_dtype=torch_dtype)
22
+ # ########################################################################## #
23
 
24
  # @spaces.GPU #[uncomment to use ZeroGPU]
25
  def infer(
26
  audio,
27
  target
28
  ):
29
+ # preprocess
30
  sampling_rate, wav = audio
31
  if wav.ndim > 1:
32
  wav = wav.mean(axis=1)
33
  wav = wav.astype(np.float32)
34
  wav /= np.max(np.abs(wav))
35
+
36
+ user_pron = language.asr(wav)
37
 
38
  # compare texts
39
+ d_toks = language.compare(target, user_pron)
 
40
  return (user_pron, d_toks)
41
 
42
  css = """
 
57
  label='Comparison',
58
  combine_adjacent=True,
59
  show_legend=True,
60
+ color_map=language.compare_colors
61
  )
62
  input_audio.input(fn=infer, inputs=[input_audio, target], outputs=[output, diff])
63