Luigi committed on
Commit 1b8a47d · 1 Parent(s): d0509e1

add s2t conversion, enable spk diarization by default
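
The "s2t" in the message refers to OpenCC's Simplified-to-Traditional Chinese conversion, which the diff below applies to every transcript. As a minimal sketch of what that conversion does (assuming the opencc package is installed; the sample string is only illustrative):

import opencc

# Same converter the commit creates: the 's2t.json' config maps Simplified to Traditional characters.
converter = opencc.OpenCC('s2t.json')
print(converter.convert("简体中文测试"))  # -> 簡體中文測試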

Files changed (1)
  1. app.py +30 -28
app.py CHANGED
@@ -7,6 +7,7 @@ import gradio as gr
 from transformers import pipeline
 from pydub import AudioSegment
 from pyannote.audio import Pipeline as DiarizationPipeline
+import opencc
 
 import spaces  # zeroGPU support
 from funasr import AutoModel
@@ -14,36 +15,21 @@ from funasr.utils.postprocess_utils import rich_transcription_postprocess
 
 # —————— Model Lists ——————
 WHISPER_MODELS = [
+    # Base Whisper models
     "openai/whisper-large-v3-turbo",
     "openai/whisper-large-v3",
-    "openai/whisper-tiny",
-    "openai/whisper-small",
     "openai/whisper-medium",
+    "openai/whisper-small",
     "openai/whisper-base",
+    "openai/whisper-tiny",
+    # Community fine-tuned Chinese models
     "JacobLinCool/whisper-large-v3-turbo-common_voice_19_0-zh-TW",
     "Jingmiao/whisper-small-zh_tw",
     "DDTChen/whisper-medium-zh-tw",
     "kimbochen/whisper-small-zh-tw",
-    "JacobLinCool/whisper-large-v3-turbo-zh-TW-clean-1",
-    "JunWorks/whisper-small-zhTW",
-    "WANGTINGTING/whisper-large-v2-zh-TW-vol2",
-    "xmzhu/whisper-tiny-zh-TW",
-    "ingrenn/whisper-small-common-voice-13-zh-TW",
-    "jun-han/whisper-small-zh-TW",
-    "xmzhu/whisper-tiny-zh-TW-baseline",
-    "JacobLinCool/whisper-large-v3-turbo-common_voice_16_1-zh-TW-2",
-    "JacobLinCool/whisper-large-v3-common_voice_19_0-zh-TW-full-1",
-    "momo103197/whisper-small-zh-TW-mix",
-    "JacobLinCool/whisper-large-v3-turbo-zh-TW-clean-1-merged",
-    "JacobLinCool/whisper-large-v2-common_voice_19_0-zh-TW-full-1",
-    "kimas1269/whisper-meduim_zhtw",
-    "JunWorks/whisper-base-zhTW",
-    "JunWorks/whisper-small-zhTW-frozenDecoder",
-    "sandy1990418/whisper-large-v3-turbo-zh-tw",
-    "JacobLinCool/whisper-large-v3-turbo-common_voice_16_1-zh-TW-pissa-merged",
-    "momo103197/whisper-small-zh-TW-16",
-    "k1nto/Belle-whisper-large-v3-zh-punct-ct2"
+    # ...etc...
 ]
+
 SENSEVOICE_MODELS = [
     "FunAudioLLM/SenseVoiceSmall",
     "AXERA-TECH/SenseVoice",
@@ -65,6 +51,7 @@ WHISPER_LANGUAGES = [
     "th","tk","tl","tr","tt","uk","ur","uz","vi","yi","yo",
     "zh","yue"
 ]
+
 SENSEVOICE_LANGUAGES = ["auto", "zh", "yue", "en", "ja", "ko", "nospeech"]
 
 # —————— Caches ——————
@@ -72,6 +59,9 @@ whisper_pipes = {}
 sense_models = {}
 dar_pipe = None
 
+# Initialize OpenCC converter for simplified to traditional Chinese
+converter = opencc.OpenCC('s2t.json')
+
 # —————— Helpers ——————
 def get_whisper_pipe(model_id: str, device: int):
     key = (model_id, device)
@@ -105,14 +95,14 @@ def get_diarization_pipe():
     if dar_pipe is None:
         # Pull token from environment (HF_TOKEN or HUGGINGFACE_TOKEN)
         token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
-        # Attempt to load the latest 3.1 pipeline, fallback to 2.1 if gated segmentation-3.0 isn't accepted
+        # Try loading latest 3.1 pipeline, fallback to 2.1 on gated model error
        try:
             dar_pipe = DiarizationPipeline.from_pretrained(
                 "pyannote/speaker-diarization-3.1",
                 use_auth_token=token or True
             )
         except Exception as e:
-            print(f"Failed to load pyannote/speaker-diarization-3.1: {e} Falling back to pyannote/[email protected].")
+            print(f"Failed to load pyannote/speaker-diarization-3.1: {e}\nFalling back to pyannote/[email protected].")
             dar_pipe = DiarizationPipeline.from_pretrained(
                 "pyannote/[email protected]",
                 use_auth_token=token or True
@@ -133,6 +123,8 @@ def transcribe_whisper(model_id: str,
     result = (pipe(audio_path) if language == "auto"
               else pipe(audio_path, generate_kwargs={"language": language}))
     transcript = result.get("text", "").strip()
+    # convert simplified Chinese to traditional
+    transcript = converter.convert(transcript)
     diar_text = ""
     # optional speaker diarization
     if enable_diar:
@@ -148,6 +140,8 @@ def transcribe_whisper(model_id: str,
                        else pipe(tmp.name, generate_kwargs={"language": language}))
             os.unlink(tmp.name)
             text = seg_out.get("text", "").strip()
+            # convert simplified Chinese to traditional
+            text = converter.convert(text)
             snippets.append(f"[{speaker}] {text}")
         diar_text = "\n".join(snippets)
     return transcript, diar_text
@@ -173,6 +167,8 @@ def transcribe_sense(model_id: str,
         text = rich_transcription_postprocess(segs[0]['text'])
         if not enable_punct:
             text = re.sub(r"[^\w\s]", "", text)
+        # convert simplified Chinese to traditional
+        text = converter.convert(text)
         return text, ""
     # with diarization
     diarizer = get_diarization_pipe()
@@ -196,6 +192,8 @@ def transcribe_sense(model_id: str,
         txt = rich_transcription_postprocess(segs[0]['text'])
         if not enable_punct:
             txt = re.sub(r"[^\w\s]", "", txt)
+        # convert simplified Chinese to traditional
+        txt = converter.convert(txt)
         snippets.append(f"[{speaker}] {txt}")
     full = rich_transcription_postprocess(model.generate(
         input=audio_path,
@@ -208,20 +206,23 @@ def transcribe_sense(model_id: str,
     )[0]['text'])
     if not enable_punct:
         full = re.sub(r"[^\w\s]", "", full)
+    full = converter.convert(full)
     return full, "\n".join(snippets)
 
 # —————— Gradio UI ——————
 demo = gr.Blocks()
 with demo:
-    gr.Markdown("## Whisper vs. SenseVoice (Language, Device & Diarization)")
+    gr.Markdown("## Whisper vs. SenseVoice (Language, Device & Diarization with Simplified→Traditional Chinese)")
+
     audio_input = gr.Audio(sources=["upload","microphone"], type="filepath", label="Audio Input")
+
     with gr.Row():
         with gr.Column():
             gr.Markdown("### Whisper ASR")
             whisper_dd = gr.Dropdown(choices=WHISPER_MODELS, value=WHISPER_MODELS[0], label="Whisper Model")
             whisper_lang = gr.Dropdown(choices=WHISPER_LANGUAGES, value="auto", label="Whisper Language")
             device_radio = gr.Radio(choices=["GPU","CPU"], value="GPU", label="Device")
-            diar_check = gr.Checkbox(label="Enable Diarization")
+            diar_check = gr.Checkbox(label="Enable Diarization", value=True)
             btn_w = gr.Button("Transcribe with Whisper")
             out_w = gr.Textbox(label="Transcript")
             out_w_d = gr.Textbox(label="Diarized Transcript")
@@ -232,13 +233,14 @@ with demo:
             gr.Markdown("### FunASR SenseVoice ASR")
             sense_dd = gr.Dropdown(choices=SENSEVOICE_MODELS, value=SENSEVOICE_MODELS[0], label="SenseVoice Model")
             sense_lang = gr.Dropdown(choices=SENSEVOICE_LANGUAGES, value="auto", label="SenseVoice Language")
-            punct = gr.Checkbox(label="Enable Punctuation", value=True)
-            diar_s = gr.Checkbox(label="Enable Diarization")
+            punct_chk = gr.Checkbox(label="Enable Punctuation", value=True)
+            diar_s_chk = gr.Checkbox(label="Enable Diarization", value=True)
             btn_s = gr.Button("Transcribe with SenseVoice")
             out_s = gr.Textbox(label="Transcript")
             out_s_d = gr.Textbox(label="Diarized Transcript")
             btn_s.click(fn=transcribe_sense,
-                        inputs=[sense_dd, sense_lang, audio_input, punct, diar_s],
+                        inputs=[sense_dd, sense_lang, audio_input, punct_chk, diar_s_chk],
                         outputs=[out_s, out_s_d])
+
 if __name__ == "__main__":
     demo.launch()
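
Because this commit turns speaker diarization on by default, the pyannote pipeline loaded in get_diarization_pipe() now runs on every request unless the checkbox is cleared. Below is a minimal, self-contained sketch of what that pipeline produces; the audio file name is a placeholder and an HF token with accepted access to the gated model is assumed.

import os
from pyannote.audio import Pipeline as DiarizationPipeline

token = os.getenv("HF_TOKEN")  # gated model: a token with accepted access is assumed
diarizer = DiarizationPipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=token or True,
)

# The pipeline returns an Annotation; each track is a speech turn with a speaker label,
# which the app then re-transcribes per segment and prefixes as "[SPEAKER_XX] ...".
diarization = diarizer("example.wav")  # placeholder audio path
for turn, _, speaker in diarization.itertracks(yield_label=True):
    print(f"[{speaker}] {turn.start:.1f}s - {turn.end:.1f}s")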