datxy committed on
Commit 8935a60 · verified · 1 Parent(s): 93526af

Update app.py

Files changed (1): app.py +126 -196
app.py CHANGED
@@ -1,12 +1,10 @@
- import os
- import io
- import math
- import tempfile
  from typing import List, Tuple

  import numpy as np
  import gradio as gr
  import librosa

  try:
      from scipy.ndimage import median_filter
@@ -14,27 +12,31 @@ try:
  except Exception:
      _HAS_SCIPY = False

- import torch
  from transformers import pipeline

- # ================== Defaults (UI initial values & limits) ==================
  MODEL_NAME = "openai/whisper-large-v3"
  BATCH_SIZE = 8
  FILE_LIMIT_MB = 1000

- DEF_SILENCE_MIN_LEN = 0.45   # minimum duration (s) of a pause / silent stretch
- DEF_DB_DROP = 25.0           # drop-below-peak threshold (max_db - DB_DROP)
- DEF_PCTL_FLOOR = 20.0        # energy-percentile floor in dB (larger = more conservative)
- DEF_MIN_SEG_DUR = 1.00       # minimum display duration per segment
- DEF_FRAME_LEN_MS = 25        # frame length for energy analysis
- DEF_HOP_LEN_MS = 10          # hop length
- DEF_CUT_OFFSET_SEC = 0.00    # global cut offset (overall calibration)
- DEF_CHUNK_LEN_S = 20         # ASR chunk length; smaller means less timestamp drift
- DEF_STRIDE_LEN_S = 2         # ASR overlap length; stabilizes timestamps across chunks
- SR_TARGET = 16000            # unified sample rate (Whisper expects 16 kHz)
-
- # ================== ASR pipeline ==================
- def _get_device_and_dtype():
      if torch.cuda.is_available():
          return 0, torch.float16
      elif torch.backends.mps.is_available():
@@ -42,32 +44,38 @@ def _get_device_and_dtype():
      else:
          return -1, torch.float32

- DEVICE, DTYPE = _get_device_and_dtype()
-
- asr = pipeline(
-     task="automatic-speech-recognition",
-     model=MODEL_NAME,
-     device=DEVICE,
-     torch_dtype=DTYPE,
-     # key: word-level timestamps
-     return_timestamps="word",
-     chunk_length_s=DEF_CHUNK_LEN_S,
-     stride_length_s=DEF_STRIDE_LEN_S,
- )
-
- # ================== Utilities ==================
  def _load_audio(path: str, sr: int = SR_TARGET):
      y, sr = librosa.load(path, sr=sr, mono=True)
      return y, sr

  def _to_db(rms: np.ndarray):
-     # librosa's amplitude_to_db expects amplitudes as input;
-     # make sure we never divide by zero here
      ref = np.maximum(np.max(rms), 1e-10)
-     db = 20.0 * np.log10(np.maximum(rms, 1e-10) / ref)
-     return db

- def _format_ts(sec: float) -> str:
      if sec < 0: sec = 0.0
      h = int(sec // 3600)
      m = int((sec % 3600) // 60)
@@ -76,10 +84,6 @@ def _format_ts(sec: float) -> str:
      return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

  def _extract_word_stream(chunks) -> List[Tuple[str, float, float]]:
-     """
-     Normalize the chunks returned by the pipeline into [(text, start, end), ...].
-     Accepts the field names timestamp/timestamps/start/end/time_start/time_end.
-     """
      out = []
      if not chunks:
          return out
@@ -87,9 +91,8 @@ def _extract_word_stream(chunks) -> List[Tuple[str, float, float]]:
          txt = (ch.get("text") or "").strip()
          ts = ch.get("timestamp", ch.get("timestamps", None))
          if ts is None:
-             # some implementations use start/end or time_start/time_end
              s = ch.get("start", ch.get("time_start", None))
-             e = ch.get("end", ch.get("time_end", None))
              if s is not None and e is not None and txt:
                  s = float(s); e = float(e)
                  if e < s: e = s
@@ -104,75 +107,53 @@ def _extract_word_stream(chunks) -> List[Tuple[str, float, float]]:
  def _detect_silence_cuts(
      y: np.ndarray,
      sr: int,
-     silence_min_len: float = DEF_SILENCE_MIN_LEN,
-     db_drop: float = DEF_DB_DROP,
-     pctl_floor: float = DEF_PCTL_FLOOR,
-     frame_len_ms: int = DEF_FRAME_LEN_MS,
-     hop_len_ms: int = DEF_HOP_LEN_MS,
- ) -> Tuple[List[float], float]:
-     """
-     Find cut points via RMS (dB) plus a lowest-point strategy; returns [cut_times], total_dur.
-     """
-     frame_len = int(sr * frame_len_ms / 1000)
-     hop_len = int(sr * hop_len_ms / 1000)
-     frame_len = max(256, frame_len)
-     hop_len = max(64, hop_len)

      rms = librosa.feature.rms(y=y, frame_length=frame_len, hop_length=hop_len, center=True)[0]
      rms_db = _to_db(rms)
-
      if _HAS_SCIPY:
-         rms_db = median_filter(rms_db, size=5)  # light smoothing

      max_db = float(np.max(rms_db))
      floor_db = float(np.percentile(rms_db, pctl_floor))
      thr = max(max_db - db_drop, floor_db)

-     # anything below the threshold counts as "silence / low energy"
-     low_mask = rms_db <= thr
-
-     # find contiguous low-energy runs
      cut_times = []
      i = 0
-     n = len(low_mask)
-     min_frames = int(silence_min_len * sr / hop_len)
      while i < n:
-         if not low_mask[i]:
-             i += 1
-             continue
          j = i + 1
-         while j < n and low_mask[j]:
              j += 1
-         span = j - i
-         if span >= max(1, min_frames):
              local = rms_db[i:j]
              k = int(np.argmin(local))
              best = i + k
-             cut_t = best * hop_len / sr
-             cut_times.append(float(cut_t))
          i = j

-     total_dur = float(len(y) / sr)
-     # dedupe, sort, and trim away the file boundaries
-     cut_times = sorted(set(t for t in cut_times if 0.05 <= t <= total_dur - 0.05))
-     return cut_times, total_dur
-
- def _snap_to_word_boundaries(
-     cut_times: List[float],
-     word_stream: List[Tuple[str, float, float]],
-     max_dist: float = 0.25
- ) -> List[float]:
-     if not cut_times or not word_stream:
-         return cut_times
-     bounds = []
-     for _, s, e in word_stream:
-         bounds.append(s); bounds.append(e)
-     bounds = sorted(set(bounds))
      snapped = []
-     for t in cut_times:
          idx = min(range(len(bounds)), key=lambda i: abs(bounds[i]-t))
          snapped.append(bounds[idx] if abs(bounds[idx]-t) <= max_dist else t)
-     # dedupe and enforce a minimum gap
      snapped = sorted(set(snapped))
      out = []
      for t in snapped:
@@ -180,101 +161,60 @@ def _snap_to_word_boundaries(
          out.append(t)
      return out

- def _segment_by_energy(
-     word_stream: List[Tuple[str, float, float]],
-     cut_times: List[float],
-     total_dur: float,
-     min_seg_dur: float = DEF_MIN_SEG_DUR,
- ) -> List[Tuple[float, float, str]]:
-     """
-     Split the word stream at cut_times; segments shorter than min_seg_dur are
-     merged into a neighbor (preferring the right one).
-     Returns [(st, en, text), ...].
-     """
-     if not word_stream:
-         # no word stream: return a single whole-file placeholder with empty text
-         return [(0.0, total_dur, "").copy()]
-
-     bnds = [0.0] + [t for t in cut_times if 0.0 < t < total_dur] + [total_dur]
      segs = []
-     wi = 0
-     W = len(word_stream)
-
-     for i in range(len(bnds) - 1):
          L, R = bnds[i], bnds[i+1]
          texts, starts, ends = [], [], []
-         # collect the words that intersect [L, R]
-         while wi < W and word_stream[wi][2] <= L:
              wi += 1
          wj = wi
-         while wj < W and word_stream[wj][1] < R:
-             txt, s, e = word_stream[wj]
              if e > L and s < R:
-                 texts.append(txt)
-                 starts.append(s)
-                 ends.append(e)
              wj += 1
-
          if texts:
-             st = float(min(starts)); en = float(max(ends))
-             # clamp to the window
-             st = max(st, L); en = min(en, R)
-             tx = " ".join(texts).strip()
-             segs.append([st, en, tx])
-         else:
-             # no words, but keep the time window to avoid overly short empty segments
-             if (R - L) >= max(0.25, min_seg_dur * 0.5):
-                 segs.append([L, R, ""])
-
-     # merge overly short segments (preferring the right neighbor)
-     def has_punc(t: str) -> bool:
-         return any(p in t for p in ",。!?,.!?;;::")

      i = 0
      while i < len(segs):
          st, en, tx = segs[i]
-         if (en - st) < min_seg_dur and len(segs) > 1:
-             # pick a merge target: right first, then left; if both exist, prefer the one with punctuation
-             target = None
              cand = []
              if i + 1 < len(segs): cand.append(i + 1)
              if i - 1 >= 0: cand.append(i - 1)
-             if not cand:
-                 i += 1
-                 continue
-             # prefer the candidate containing punctuation
              cand.sort(key=lambda j: (not has_punc(segs[j][2]), abs(j - i)))
-             target = cand[0]
-             # merge
-             nst = min(segs[target][0], st)
-             nen = max(segs[target][1], en)
-             ntx = " ".join([segs[target][2], tx]).strip() if target < i else " ".join([tx, segs[target][2]]).strip()
-             # keep the result at the smaller index and delete the other segment
-             keep, drop = (target, i) if target < i else (i, target)
              segs[keep] = [nst, nen, ntx]
              del segs[drop]
-             # after merging, re-examine from the earlier index
-             i = max(0, keep - 1)
-             continue
          i += 1

-     # drop extremely short segments
-     out = []
-     for st, en, tx in segs:
-         if (en - st) < 0.12:
-             continue
-         out.append((float(st), float(en), tx.strip()))
-     return out

- def _build_srt(segs: List[Tuple[float, float, str]]) -> str:
      lines = []
      for idx, (st, en, tx) in enumerate(segs, start=1):
          lines.append(str(idx))
-         lines.append(f"{_format_ts(st)} --> {_format_ts(en)}")
-         lines.append(tx if tx else "")
          lines.append("")
      return "\n".join(lines).strip() + "\n"

- # ================== Main flow ==================
  def transcribe_and_split(
      audio_path: str,
      silence_min_len: float = DEF_SILENCE_MIN_LEN,
@@ -288,30 +228,26 @@ def transcribe_and_split(
      if not audio_path:
          raise gr.Error("请先上传或录制音频。")

-     # file-size limit
      try:
-         fsize_mb = os.path.getsize(audio_path) / (1024 * 1024)
-         if fsize_mb > FILE_LIMIT_MB:
-             raise gr.Error(f"文件过大:{fsize_mb:.1f} MB,超过上限 {FILE_LIMIT_MB} MB。")
      except Exception:
          pass

-     # ASR
      result = asr(
          audio_path,
-         # these can also be passed per call (defaults were set at construction time)
          return_timestamps="word",
          chunk_length_s=DEF_CHUNK_LEN_S,
          stride_length_s=DEF_STRIDE_LEN_S,
          batch_size=BATCH_SIZE,
      )
      text = (result.get("text") or "").strip()
-     chunks = result.get("chunks") or []
-     words = _extract_word_stream(chunks)

-     # energy-based cuts (computed on 16 kHz mono audio)
      y, sr = _load_audio(audio_path, sr=SR_TARGET)
-     cut_times, total_dur = _detect_silence_cuts(
          y, sr,
          silence_min_len=silence_min_len,
          db_drop=db_drop,
@@ -320,55 +256,49 @@ def transcribe_and_split(
          hop_len_ms=hop_len_ms,
      )

-     # global offset applied to all cut points
      if abs(cut_offset_sec) > 1e-6:
-         cut_times = [max(0.0, min(total_dur, t + cut_offset_sec)) for t in cut_times]

-     # snap cut points to the nearest word boundary
-     cut_times = _snap_to_word_boundaries(cut_times, words, max_dist=0.25)
-
-     # build segments from cut points + word stream
-     segs = _segment_by_energy(words, cut_times, total_dur, min_seg_dur=min_seg_dur)
-
-     # fallback: no words at all (extreme case), emit one whole-file segment
      if not segs:
-         segs = [(0.0, total_dur, text)]

-     srt_text = _build_srt(segs)
-
-     # save to a temporary .srt file for download
      tmpf = tempfile.NamedTemporaryFile(delete=False, suffix=".srt")
-     tmpf.write(srt_text.encode("utf-8"))
-     tmpf.flush(); tmpf.close()

-     return srt_text, tmpf.name

- # ================== Gradio UI ==================
  with gr.Blocks(title="Whisper Large V3 · 智能切分 SRT", theme=gr.themes.Soft()) as demo:
      gr.Markdown("### 🎧 Whisper Large V3 · 更稳的 SRT 切分\n"
                  "- 词级时间戳 + 能量最低点切分 + 词边界吸附\n"
-                 "- 片段时长不足将自动与邻段合并(优先右侧)\n")

-     with gr.Row():
-         audio = gr.Audio(sources=["upload", "microphone"], type="filepath", label="音频(上传或录制)")
      with gr.Accordion("高级参数", open=False):
          with gr.Row():
              silence_min_len = gr.Slider(0.1, 1.0, value=DEF_SILENCE_MIN_LEN, step=0.05, label="静音最短时长 (s)")
-             db_drop = gr.Slider(10, 40, value=DEF_DB_DROP, step=1.0, label="相对峰值下落 (dB)")
-             pctl_floor = gr.Slider(0, 50, value=DEF_PCTL_FLOOR, step=1.0, label="能量分位下限 (dB)")
          with gr.Row():
              min_seg_dur = gr.Slider(0.3, 3.0, value=DEF_MIN_SEG_DUR, step=0.05, label="最短片段时长 (s)")
-             frame_len_ms = gr.Slider(10, 50, value=DEF_FRAME_LEN_MS, step=1, label="帧长 (ms)")
-             hop_len_ms = gr.Slider(5, 25, value=DEF_HOP_LEN_MS, step=1, label="帧移 (ms)")
          cut_offset_sec = gr.Slider(-0.20, 0.20, value=DEF_CUT_OFFSET_SEC, step=0.01, label="切分整体偏移 (s)")

      btn = gr.Button("开始识别并生成 SRT", variant="primary")
-     with gr.Row():
-         srt_preview = gr.Textbox(lines=16, label="SRT 预览", show_copy_button=True)
      srt_file = gr.File(label="下载 SRT 文件", file_count="single")

      btn.click(
-         fn=transcribe_and_split,
          inputs=[audio, silence_min_len, db_drop, pctl_floor, min_seg_dur, frame_len_ms, hop_len_ms, cut_offset_sec],
          outputs=[srt_preview, srt_file],
      )
 
+ import os, io, math, tempfile
  from typing import List, Tuple

  import numpy as np
  import gradio as gr
  import librosa
+ import torch

  try:
      from scipy.ndimage import median_filter
  except Exception:
      _HAS_SCIPY = False

  from transformers import pipeline
+ import spaces  # key: needed for ZeroGPU

+ # ================== Defaults ==================
  MODEL_NAME = "openai/whisper-large-v3"
  BATCH_SIZE = 8
  FILE_LIMIT_MB = 1000

+ DEF_SILENCE_MIN_LEN = 0.45
+ DEF_DB_DROP = 25.0
+ DEF_PCTL_FLOOR = 20.0
+ DEF_MIN_SEG_DUR = 1.00
+ DEF_FRAME_LEN_MS = 25
+ DEF_HOP_LEN_MS = 10
+ DEF_CUT_OFFSET_SEC = 0.00
+ DEF_CHUNK_LEN_S = 20
+ DEF_STRIDE_LEN_S = 2
+ SR_TARGET = 16000
+
+ # ================== Global lazy loading ==================
+ _ASR = None
+ _ASR_DEVICE = None
+ _ASR_DTYPE = None
+
+ def _pick_device_dtype():
      if torch.cuda.is_available():
          return 0, torch.float16
      elif torch.backends.mps.is_available():

      else:
          return -1, torch.float32

+ def _get_asr():
+     """
+     On ZeroGPU, the first call must happen inside a @spaces.GPU-decorated
+     function, otherwise CUDA is unavailable. Also works on CPU and on
+     regular GPUs.
+     """
+     global _ASR, _ASR_DEVICE, _ASR_DTYPE
+     dev, dt = _pick_device_dtype()
+     if _ASR is None or _ASR_DEVICE != dev:
+         _ASR = pipeline(
+             task="automatic-speech-recognition",
+             model=MODEL_NAME,
+             device=dev,
+             torch_dtype=dt,
+             return_timestamps="word",
+             chunk_length_s=DEF_CHUNK_LEN_S,
+             stride_length_s=DEF_STRIDE_LEN_S,
+             ignore_warning=True,
+         )
+         _ASR_DEVICE, _ASR_DTYPE = dev, dt
+         print(f"[ASR] Initialized on device={dev} dtype={dt}")
+     return _ASR
+
+ # ================== 音频 & 工具 ==================
70
  def _load_audio(path: str, sr: int = SR_TARGET):
71
  y, sr = librosa.load(path, sr=sr, mono=True)
72
  return y, sr
73
 
74
  def _to_db(rms: np.ndarray):
 
 
75
  ref = np.maximum(np.max(rms), 1e-10)
76
+ return 20.0 * np.log10(np.maximum(rms, 1e-10) / ref)
 
77
 
78
+ def _fmt_ts(sec: float) -> str:
79
  if sec < 0: sec = 0.0
80
  h = int(sec // 3600)
81
  m = int((sec % 3600) // 60)
 
84
  return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"

  def _extract_word_stream(chunks) -> List[Tuple[str, float, float]]:
      out = []
      if not chunks:
          return out

          txt = (ch.get("text") or "").strip()
          ts = ch.get("timestamp", ch.get("timestamps", None))
          if ts is None:
              s = ch.get("start", ch.get("time_start", None))
+             e = ch.get("end", ch.get("time_end", None))
              if s is not None and e is not None and txt:
                  s = float(s); e = float(e)
                  if e < s: e = s
 
  def _detect_silence_cuts(
      y: np.ndarray,
      sr: int,
+     silence_min_len: float,
+     db_drop: float,
+     pctl_floor: float,
+     frame_len_ms: int,
+     hop_len_ms: int,
+ ):
+     frame_len = max(256, int(sr * frame_len_ms / 1000))
+     hop_len = max(64, int(sr * hop_len_ms / 1000))

      rms = librosa.feature.rms(y=y, frame_length=frame_len, hop_length=hop_len, center=True)[0]
      rms_db = _to_db(rms)
      if _HAS_SCIPY:
+         rms_db = median_filter(rms_db, size=5)

      max_db = float(np.max(rms_db))
      floor_db = float(np.percentile(rms_db, pctl_floor))
      thr = max(max_db - db_drop, floor_db)

+     low = rms_db <= thr
      cut_times = []
+     n = len(low)
      i = 0
+     min_frames = max(1, int(silence_min_len * sr / hop_len))
      while i < n:
+         if not low[i]:
+             i += 1; continue
          j = i + 1
+         while j < n and low[j]:
              j += 1
+         if (j - i) >= min_frames:
              local = rms_db[i:j]
              k = int(np.argmin(local))
              best = i + k
+             cut_times.append(best * hop_len / sr)
          i = j

+     total = float(len(y) / sr)
+     cut_times = sorted(set(t for t in cut_times if 0.05 <= t <= total - 0.05))
+     return cut_times, total
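# With the defaults and SR_TARGET = 16000: frame_len = 16000*25/1000 = 400
# samples, hop_len = 16000*10/1000 = 160 samples, and min_frames =
# int(0.45*16000/160) = 45, so roughly 0.45 s of consecutive low-energy
# frames is required before a cut is placed at the quietest frame of the run.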
+
+ def _snap_to_word_bounds(cuts: List[float], words: List[Tuple[str, float, float]], max_dist=0.25):
+     if not cuts or not words: return cuts
+     bounds = sorted({b for _, s, e in words for b in (s, e)})
      snapped = []
+     for t in cuts:
          idx = min(range(len(bounds)), key=lambda i: abs(bounds[i]-t))
          snapped.append(bounds[idx] if abs(bounds[idx]-t) <= max_dist else t)
      snapped = sorted(set(snapped))
      out = []
      for t in snapped:

          out.append(t)
      return out
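# Illustration (hypothetical numbers): with word boundaries at
# [1.20, 1.48, 1.95], an energy cut at t = 1.40 snaps to 1.48
# (|1.48 - 1.40| = 0.08 <= max_dist), while a cut at t = 2.60 stays put
# because the nearest boundary (1.95) is 0.65 s away.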

+ def _segment(words: List[Tuple[str, float, float]], cuts: List[float], total: float, min_seg: float):
+     if not words:
+         return [(0.0, total, "")]
+     bnds = [0.0] + [t for t in cuts if 0.0 < t < total] + [total]
      segs = []
+     wi, W = 0, len(words)
+     for i in range(len(bnds)-1):
          L, R = bnds[i], bnds[i+1]
          texts, starts, ends = [], [], []
+         while wi < W and words[wi][2] <= L:
              wi += 1
          wj = wi
+         while wj < W and words[wj][1] < R:
+             txt, s, e = words[wj]
              if e > L and s < R:
+                 texts.append(txt); starts.append(s); ends.append(e)
              wj += 1
          if texts:
+             st, en = max(min(starts), L), min(max(ends), R)
+             segs.append([float(st), float(en), " ".join(texts).strip()])
+         elif (R - L) >= max(0.25, min_seg * 0.5):
+             segs.append([L, R, ""])

+     def has_punc(t): return any(p in t for p in ",。!?,.!?;;::")
      i = 0
      while i < len(segs):
          st, en, tx = segs[i]
+         if (en - st) < min_seg and len(segs) > 1:
              cand = []
              if i + 1 < len(segs): cand.append(i + 1)
              if i - 1 >= 0: cand.append(i - 1)
              cand.sort(key=lambda j: (not has_punc(segs[j][2]), abs(j - i)))
+             t = cand[0]
+             nst, nen = min(segs[t][0], st), max(segs[t][1], en)
+             ntx = (" ".join([segs[t][2], tx]) if t < i else " ".join([tx, segs[t][2]])).strip()
+             keep, drop = (t, i) if t < i else (i, t)
              segs[keep] = [nst, nen, ntx]
              del segs[drop]
+             i = max(0, keep - 1); continue
          i += 1

+     return [(st, en, tx.strip()) for st, en, tx in segs if (en - st) >= 0.12]
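# Merge illustration (hypothetical numbers): segments
# [(0.0, 0.6, "Hi"), (0.6, 3.1, "there, everyone.")] with min_seg = 1.0
# collapse into [(0.0, 3.1, "Hi there, everyone.")]; the right neighbor
# wins the candidate sort because it contains punctuation.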

+ def _build_srt(segs: List[Tuple[float, float, str]]) -> str:
      lines = []
      for idx, (st, en, tx) in enumerate(segs, start=1):
          lines.append(str(idx))
+         lines.append(f"{_fmt_ts(st)} --> {_fmt_ts(en)}")
+         lines.append(tx)
          lines.append("")
      return "\n".join(lines).strip() + "\n"

+ # ================== Inference core (runs on the GPU) ==================
+ @spaces.GPU  # key: ZeroGPU entry point (called when the button is clicked)
  def transcribe_and_split(
      audio_path: str,
      silence_min_len: float = DEF_SILENCE_MIN_LEN,

      if not audio_path:
          raise gr.Error("请先上传或录制音频。")

      try:
+         if os.path.getsize(audio_path) / (1024*1024) > FILE_LIMIT_MB:
+             raise gr.Error(f"文件过大,超过 {FILE_LIMIT_MB} MB。")
      except Exception:
          pass

+     asr = _get_asr()  # first created on the GPU
+
      result = asr(
          audio_path,

          return_timestamps="word",
          chunk_length_s=DEF_CHUNK_LEN_S,
          stride_length_s=DEF_STRIDE_LEN_S,
          batch_size=BATCH_SIZE,
      )
      text = (result.get("text") or "").strip()
+     words = _extract_word_stream(result.get("chunks") or [])

      y, sr = _load_audio(audio_path, sr=SR_TARGET)
+     cuts, total = _detect_silence_cuts(
          y, sr,
          silence_min_len=silence_min_len,
          db_drop=db_drop,

          hop_len_ms=hop_len_ms,
      )

      if abs(cut_offset_sec) > 1e-6:
+         cuts = [max(0.0, min(total, t + cut_offset_sec)) for t in cuts]

+     cuts = _snap_to_word_bounds(cuts, words, max_dist=0.25)
+     segs = _segment(words, cuts, total, min_seg_dur)

      if not segs:
+         segs = [(0.0, total, text)]
+     srt = _build_srt(segs)

      tmpf = tempfile.NamedTemporaryFile(delete=False, suffix=".srt")
+     tmpf.write(srt.encode("utf-8")); tmpf.flush(); tmpf.close()
+     return srt, tmpf.name
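# End-to-end flow of this entry point: Whisper word timestamps ->
# energy-based cut candidates -> optional global offset -> snap to word
# boundaries -> segmentation with short-segment merging -> SRT text plus
# a temp-file path, both handed back to Gradio.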

+ # Give the startup check a visible GPU entry point (optional; it never has to be called)
+ @spaces.GPU
+ def gpu_warmup():
+     return "ok"

+ # ================== UI ==================
  with gr.Blocks(title="Whisper Large V3 · 智能切分 SRT", theme=gr.themes.Soft()) as demo:
      gr.Markdown("### 🎧 Whisper Large V3 · 更稳的 SRT 切分\n"
                  "- 词级时间戳 + 能量最低点切分 + 词边界吸附\n"
+                 "- 片段过短自动合并,SRT 含序号行\n")
+
+     audio = gr.Audio(sources=["upload", "microphone"], type="filepath", label="音频(上传或录制)")

      with gr.Accordion("高级参数", open=False):
          with gr.Row():
              silence_min_len = gr.Slider(0.1, 1.0, value=DEF_SILENCE_MIN_LEN, step=0.05, label="静音最短时长 (s)")
+             db_drop = gr.Slider(10, 40, value=DEF_DB_DROP, step=1.0, label="相对峰值下落 (dB)")
+             pctl_floor = gr.Slider(0, 50, value=DEF_PCTL_FLOOR, step=1.0, label="能量分位下限 (dB)")
          with gr.Row():
              min_seg_dur = gr.Slider(0.3, 3.0, value=DEF_MIN_SEG_DUR, step=0.05, label="最短片段时长 (s)")
+             frame_len_ms = gr.Slider(10, 50, value=DEF_FRAME_LEN_MS, step=1, label="帧长 (ms)")
+             hop_len_ms = gr.Slider(5, 25, value=DEF_HOP_LEN_MS, step=1, label="帧移 (ms)")
          cut_offset_sec = gr.Slider(-0.20, 0.20, value=DEF_CUT_OFFSET_SEC, step=0.01, label="切分整体偏移 (s)")

      btn = gr.Button("开始识别并生成 SRT", variant="primary")
+     srt_preview = gr.Textbox(lines=16, label="SRT 预览", show_copy_button=True)
      srt_file = gr.File(label="下载 SRT 文件", file_count="single")

      btn.click(
+         fn=transcribe_and_split,  # note: this binds the @spaces.GPU function
          inputs=[audio, silence_min_len, db_drop, pctl_floor, min_seg_dur, frame_len_ms, hop_len_ms, cut_offset_sec],
          outputs=[srt_preview, srt_file],
      )
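
For a quick smoke test outside the UI, the GPU entry point can also be invoked directly; a minimal sketch, assuming a local test file named sample.wav (the filename is illustrative, not part of the app):

    srt_text, srt_path = transcribe_and_split("sample.wav")
    print(srt_text.splitlines()[:2])  # index line and timing line of the first cue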