datxy committed on
Commit d36ea49 · verified · 1 Parent(s): 38c2998

Update app.py

Files changed (1):
  app.py +138 -54
app.py CHANGED
@@ -1,37 +1,40 @@
  # app.py
  import os
  import spaces
  import torch
  import gradio as gr
  from transformers import pipeline
- import tempfile

  # ================== Configuration and tunable parameters ==================
  MODEL_NAME = "openai/whisper-large-v3"
  BATCH_SIZE = 8
  FILE_LIMIT_MB = 1000

- # —— Silence-splitting strategy (as requested) ——
- SILENCE_GAP = 0.2    # a gap of >= 0.2 s between adjacent characters triggers a cut
- MIN_SEG_DUR = 0.30   # minimum duration per segment, to avoid subtitle flicker

  # ================== Device / precision (auto-downgrade to avoid AcceleratorError) ==================
  def _pick_device_and_dtype():
-     """
-     - GPU Space: CUDA:0, prefer FP16 (sm >= 7)
-     - ZeroGPU / CPU: force CPU (-1) + FP32
-     - Avoids the AcceleratorError caused by @spaces.GPU and float16 on CPU/ZeroGPU
-     """
      is_zero_gpu = (
          os.environ.get("SYSTEM") == "spaces"
          and os.environ.get("SPACE_ACCELERATOR", "").lower() == "zero-gpu"
      )
-
      if torch.cuda.is_available() and not is_zero_gpu:
          try:
-             major, minor = torch.cuda.get_device_capability(0)
          except Exception:
-             major = 7  # safe fallback
          use_fp16 = major >= 7
          return 0, (torch.float16 if use_fp16 else torch.float32)
      else:
@@ -49,7 +52,7 @@ asr = pipeline(
      return_timestamps=True,  # have the result include (start, end) for each chunk
  )

- # ================== Utility functions ==================
  def _ts(t: float | None) -> str:
      """Seconds -> SRT timestamp 00:00:00,000"""
      if t is None or t < 0:
@@ -99,62 +102,142 @@ def _char_timeline(chunk: dict) -> list[tuple[str, float, float]]:
          cur = nxt
      return timeline

- # ================== Split on silence gaps (no reliance on punctuation) ==================
- def _segment_by_silence(char_stream: list[tuple[str, float, float]],
-                         silence_gap: float = SILENCE_GAP,
-                         min_seg_dur: float = MIN_SEG_DUR) -> list[tuple[float, float, str]]:
      """
-     Split on the time gap between adjacent characters:
-     - if (current character's start time - previous character's end time) >= silence_gap, cut a segment immediately
-     - no punctuation handling; all characters are concatenated as-is
-     - minimum-duration protection for each segment
      """
      segments = []
-     buf_chars: list[str] = []
-     buf_start: float | None = None
-     last_end: float | None = None
-     prev_end: float | None = None

-     def flush():
-         nonlocal buf_chars, buf_start, last_end
          if not buf_chars:
-             return
-         text = "".join(buf_chars).strip()
-         st = buf_start if buf_start is not None else 0.0
-         en = last_end if last_end is not None else st
-         if en - st < min_seg_dur:
-             en = st + min_seg_dur
-         segments.append((st, en, text))
-         buf_chars, buf_start, last_end = [], None, None

      for ch, st, en in char_stream:
-         if buf_start is None:
-             buf_start = st
-         if prev_end is not None and (st - prev_end) >= silence_gap:
-             flush()
-             buf_start = st
          buf_chars.append(ch)
-         last_end = en
-         prev_end = en

-     flush()
      return segments

- def chunks_to_srt_no_number(chunks: list[dict]) -> str:
      """
-     Normalize chunks → global per-character timeline → split on silence → emit SRT without numbering
      """
      norm = _norm_chunks(chunks)
      char_stream = []
      for ch in norm:
          char_stream.extend(_char_timeline(ch))

-     segs = _segment_by_silence(char_stream)

      lines = []
      for st, en, tx in segs:
          lines.append(f"{_ts(st)} --> {_ts(en)}")
-         lines.append(tx.strip())
          lines.append("")
      return "\n".join(lines).strip() + ("\n" if lines else "")
@@ -177,13 +260,12 @@ def transcribe_file_to_srt(audio_path: str, task: str):

      # Run ASR
      result = asr(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task})
-     # Accommodate the transformers return format
      text = (result.get("text") or "").strip() if isinstance(result, dict) else ""
      chunks = result.get("chunks") if isinstance(result, dict) else None
      chunks = chunks or []

-     # Generate the SRT via silence-based splitting (no numbering)
-     srt_str = chunks_to_srt_no_number(chunks)

      # Fallback: if there are no chunks, emit a single whole-recording subtitle
      if not srt_str and text:
@@ -205,12 +287,14 @@ demo = gr.Interface(
          gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
      ],
      outputs=[
-         gr.Textbox(label="Transcript (SRT Preview — silence-based, no numbering)", lines=18),
          gr.File(label="Download SRT"),
      ],
-     title="Upload Audio → SRT (Silence-Based Segmentation, No Numbering)",
-     description=f"Splitting rule: silence gap between adjacent characters ≥ {SILENCE_GAP}s; no reliance on punctuation. Model: {MODEL_NAME}",
      allow_flagging="never",
  )

- demo.queue().launch()
 
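For reference, the removed rule cut a new subtitle whenever the pause between two adjacent character timestamps reached SILENCE_GAP. A minimal standalone sketch of that rule (illustrative only; the character timings are made up, and the buffering is simplified compared with the removed _segment_by_silence):

SILENCE_GAP = 0.2   # seconds of pause that triggers a cut (value from the removed code)
MIN_SEG_DUR = 0.30  # minimum segment duration (value from the removed code)

# (char, start, end) triples such as a per-character Whisper timeline might yield
chars = [("h", 0.00, 0.10), ("i", 0.10, 0.22), ("y", 0.55, 0.70), ("o", 0.72, 0.95)]

segments, buf, seg_start, prev_end = [], [], None, None
for ch, st, en in chars:
    # a long enough pause before this character closes the current segment
    if prev_end is not None and (st - prev_end) >= SILENCE_GAP:
        segments.append((seg_start, max(prev_end, seg_start + MIN_SEG_DUR), "".join(buf)))
        buf, seg_start = [], None
    if seg_start is None:
        seg_start = st
    buf.append(ch)
    prev_end = en
if buf:
    segments.append((seg_start, max(prev_end, seg_start + MIN_SEG_DUR), "".join(buf)))

print(segments)  # [(0.0, 0.3, 'hi'), (0.55, 0.95, 'yo')]

The updated file below replaces this per-character gap rule with silence boundaries detected from the audio signal itself.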
@@ -1,37 +1,40 @@
  # app.py
  import os
+ import tempfile
+ import numpy as np
+
  import spaces
  import torch
  import gradio as gr
  from transformers import pipeline
+
+ # used for audio energy analysis
+ import librosa

  # ================== Configuration and tunable parameters ==================
  MODEL_NAME = "openai/whisper-large-v3"
  BATCH_SIZE = 8
  FILE_LIMIT_MB = 1000

+ # —— Silence-splitting strategy (energy/VAD style) ——
+ SILENCE_MIN_LEN = 0.20   # minimum duration (s) of a silent stretch; only silence lasting this long becomes a cut point
+ FRAME_LEN_MS = 25        # frame length for energy analysis (ms)
+ HOP_LEN_MS = 10          # hop length (ms)
+ DB_DROP = 25.0           # drop relative to the peak (below max energy - 25 dB counts as a silence candidate)
+ PCTL_FLOOR = 20.0        # also consult the 20th energy percentile, so a very low noise floor does not cause over-splitting
+ MIN_SEG_DUR = 0.30       # minimum duration per segment, to avoid subtitle flicker

  # ================== Device / precision (auto-downgrade to avoid AcceleratorError) ==================
  def _pick_device_and_dtype():
      is_zero_gpu = (
          os.environ.get("SYSTEM") == "spaces"
          and os.environ.get("SPACE_ACCELERATOR", "").lower() == "zero-gpu"
      )
      if torch.cuda.is_available() and not is_zero_gpu:
          try:
+             major, _ = torch.cuda.get_device_capability(0)
          except Exception:
+             major = 7
          use_fp16 = major >= 7
          return 0, (torch.float16 if use_fp16 else torch.float32)
      else:
@@ -49,7 +52,7 @@ asr = pipeline(
      return_timestamps=True,  # have the result include (start, end) for each chunk
  )

+ # ================== Basic utilities ==================
  def _ts(t: float | None) -> str:
      """Seconds -> SRT timestamp 00:00:00,000"""
      if t is None or t < 0:
@@ -99,62 +102,142 @@ def _char_timeline(chunk: dict) -> list[tuple[str, float, float]]:
          cur = nxt
      return timeline

+ # ================== Audio energy analysis (librosa) ==================
+ def _load_audio_mono(audio_path: str, sr: int = 16000):
      """
+     Load with librosa as mono at the target sample rate.
+     Returns: (y: np.ndarray[float32], sr: int)
      """
+     y, ysr = librosa.load(audio_path, sr=sr, mono=True)
+     if y.size == 0:
+         return np.zeros(1, dtype=np.float32), sr
+     return y.astype(np.float32), sr
+
+ def _detect_silence_cuts(audio_path: str,
+                          min_silence_len: float = SILENCE_MIN_LEN,
+                          frame_len_ms: float = FRAME_LEN_MS,
+                          hop_len_ms: float = HOP_LEN_MS,
+                          db_drop: float = DB_DROP,
+                          pctl_floor: float = PCTL_FLOOR) -> tuple[list[float], float]:
+     """
+     Return the cut-boundary times in seconds (at the centre of each silent stretch) and the total audio duration.
+     - The threshold is the higher of (max_db - db_drop) and the pctl_floor-th percentile (more conservative).
+     - A stretch that stays below the threshold for >= min_silence_len is recorded as silence.
+     """
+     y, sr = _load_audio_mono(audio_path, sr=16000)
+     dur = len(y) / float(sr)
+     if dur <= 0:
+         return [], 0.0
+
+     frame_len = int(sr * frame_len_ms / 1000.0)  # e.g., 25 ms
+     hop_len = int(sr * hop_len_ms / 1000.0)      # e.g., 10 ms
+     frame_len = max(frame_len, 1)
+     hop_len = max(hop_len, 1)
+
+     # RMS energy (librosa returns shape=(1, n_frames))
+     rms = librosa.feature.rms(y=y, frame_length=frame_len, hop_length=hop_len, center=True)[0]
+     # Convert to dB, avoiding log(0)
+     rms_db = 20.0 * np.log10(np.maximum(rms, 1e-9))
+
+     max_db = float(np.max(rms_db))
+     floor_db = float(np.percentile(rms_db, pctl_floor))
+     thr = max(max_db - db_drop, floor_db)  # adaptive threshold
+
+     below = rms_db < thr
+     # Find runs of consecutive True values
+     cuts = []
+     i = 0
+     n = len(below)
+     min_frames = int(np.ceil(min_silence_len * sr / hop_len))  # minimum number of consecutive frames, in hops
+
+     while i < n:
+         if below[i]:
+             j = i + 1
+             while j < n and below[j]:
+                 j += 1
+             # [i, j) is a continuous silent stretch
+             span_frames = j - i
+             if span_frames * hop_len / sr >= min_silence_len:
+                 # Use the centre of the silent stretch as the cut point (more stable)
+                 mid_frame = i + span_frames // 2
+                 cut_time = mid_frame * hop_len / sr
+                 # Ignore cut points too close to 0 or to the end
+                 if 0.05 < cut_time < (dur - 0.05):
+                     cuts.append(float(cut_time))
+             i = j
+         else:
+             i += 1
+
+     return cuts, dur
+
+ # ================== Merge the character timeline at the silence cut boundaries ==================
+ def _segment_by_energy(char_stream: list[tuple[str, float, float]],
+                        cut_times: list[float],
+                        min_seg_dur: float = MIN_SEG_DUR) -> list[tuple[float, float, str]]:
+     """
+     Split the global per-character timeline at cut_times (silence centre points).
+     """
+     if not char_stream:
+         return []
+
+     start_time = char_stream[0][1]
+     end_time = char_stream[-1][2]
+     boundaries = [t for t in cut_times if start_time < t < end_time]
+     boundaries = sorted(set(boundaries))
+
      segments = []
+     buf_chars = []
+     seg_start = start_time
+     boundary_idx = 0

+     def flush(seg_end):
+         nonlocal buf_chars, seg_start
          if not buf_chars:
+             # Even with no characters (edge case), still protect the duration
+             st = seg_start
+             en = max(seg_start + min_seg_dur, seg_end)
+             segments.append((st, en, ""))
+         else:
+             st = seg_start
+             en = max(seg_start + min_seg_dur, seg_end)
+             text = "".join(buf_chars).strip()
+             segments.append((st, en, text))
+         buf_chars = []

+     # Advance character by character; flush at each boundary
      for ch, st, en in char_stream:
+         # Handle a character that may span several boundaries (rare)
+         while boundary_idx < len(boundaries) and boundaries[boundary_idx] <= st:
+             cut_t = boundaries[boundary_idx]
+             flush(seg_end=cut_t)
+             seg_start = cut_t
+             boundary_idx += 1
+
          buf_chars.append(ch)

+     # Finish up
+     flush(seg_end=end_time)
      return segments

+ def chunks_to_srt_no_number(chunks: list[dict], audio_path: str) -> str:
      """
+     Normalize chunks → global per-character timeline → split at silence boundaries from energy analysis → emit SRT without numbering
      """
      norm = _norm_chunks(chunks)
      char_stream = []
      for ch in norm:
          char_stream.extend(_char_timeline(ch))

+     # Detect silence cut points from the audio
+     cut_times, _dur = _detect_silence_cuts(audio_path)
+
+     # Merge characters using the energy-based cut boundaries
+     segs = _segment_by_energy(char_stream, cut_times)

      lines = []
      for st, en, tx in segs:
          lines.append(f"{_ts(st)} --> {_ts(en)}")
+         lines.append(tx)
          lines.append("")
      return "\n".join(lines).strip() + ("\n" if lines else "")
@@ -177,13 +260,12 @@ def transcribe_file_to_srt(audio_path: str, task: str):

      # Run ASR
      result = asr(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task})
      text = (result.get("text") or "").strip() if isinstance(result, dict) else ""
      chunks = result.get("chunks") if isinstance(result, dict) else None
      chunks = chunks or []

+     # Generate the SRT from "energy silence" detection (no numbering)
+     srt_str = chunks_to_srt_no_number(chunks, audio_path)

      # Fallback: if there are no chunks, emit a single whole-recording subtitle
      if not srt_str and text:
@@ -205,12 +287,14 @@ demo = gr.Interface(
          gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
      ],
      outputs=[
+         gr.Textbox(label="Transcript (SRT Preview — energy-based silence, no numbering)", lines=18),
          gr.File(label="Download SRT"),
      ],
+     title="Upload Audio → SRT (Energy/VAD-Based Segmentation, No Numbering)",
+     description=(
+         f"Splitting rule: measure per-frame audio energy (RMS → dB) and cut at the centre of any silence lasting ≥ {SILENCE_MIN_LEN}s; no reliance on punctuation. Model: {MODEL_NAME}"
+     ),
      allow_flagging="never",
  )

+ demo.queue().launch()
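A standalone sketch of the energy-based cut detection added in this commit, mirroring the thresholding in _detect_silence_cuts (25 ms frames, 10 ms hop, 25 dB drop, 20th percentile floor, 0.20 s minimum silence, all values from the diff). It assumes numpy and librosa are installed; the two-tone test signal is made up for illustration:

import numpy as np
import librosa

sr = 16000
tone = 0.5 * np.sin(2 * np.pi * 440 * np.arange(int(0.8 * sr)) / sr)
gap = np.zeros(int(0.6 * sr))                      # 0.6 s of silence between two 0.8 s tones
y = np.concatenate([tone, gap, tone]).astype(np.float32)

frame_len, hop_len = int(0.025 * sr), int(0.010 * sr)   # 25 ms frames, 10 ms hop
rms = librosa.feature.rms(y=y, frame_length=frame_len, hop_length=hop_len, center=True)[0]
rms_db = 20.0 * np.log10(np.maximum(rms, 1e-9))

# Adaptive threshold: 25 dB below the peak, but never below the 20th percentile
thr = max(rms_db.max() - 25.0, np.percentile(rms_db, 20.0))

below = rms_db < thr
cuts, i = [], 0
while i < len(below):
    if below[i]:
        j = i
        while j < len(below) and below[j]:
            j += 1
        if (j - i) * hop_len / sr >= 0.20:                    # silence lasted at least SILENCE_MIN_LEN
            cuts.append((i + (j - i) // 2) * hop_len / sr)    # midpoint of the silent run
        i = j
    else:
        i += 1

print(cuts)  # roughly [1.1], the centre of the silent gap between 0.8 s and 1.4 s

The resulting cut times are what _segment_by_energy uses as boundaries when it groups the per-character Whisper timeline into subtitle lines.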