import os, io, math, tempfile
from typing import List, Tuple

import numpy as np
import gradio as gr
import librosa
import torch

try:
    from scipy.ndimage import median_filter
    _HAS_SCIPY = True
except Exception:
    _HAS_SCIPY = False

from transformers import pipeline
import spaces  # required for ZeroGPU

# ================== Default parameters ==================
MODEL_NAME = "openai/whisper-large-v3"
BATCH_SIZE = 8
FILE_LIMIT_MB = 1000

DEF_SILENCE_MIN_LEN = 0.45
DEF_DB_DROP = 25.0
DEF_PCTL_FLOOR = 20.0
DEF_MIN_SEG_DUR = 1.00
DEF_FRAME_LEN_MS = 25
DEF_HOP_LEN_MS = 10
DEF_CUT_OFFSET_SEC = 0.00

DEF_CHUNK_LEN_S = 20
DEF_STRIDE_LEN_S = 2
SR_TARGET = 16000

# ================== Lazily initialized globals ==================
_ASR = None
_ASR_DEVICE = None
_ASR_DTYPE = None


def _pick_device_dtype():
    if torch.cuda.is_available():
        return 0, torch.float16
    elif torch.backends.mps.is_available():
        return "mps", torch.float16
    else:
        return -1, torch.float32


def _get_asr():
    """
    Under ZeroGPU the first call must happen inside a function decorated with
    @spaces.GPU, otherwise CUDA is not available. Also works on CPU and on
    regular GPUs.
    """
    global _ASR, _ASR_DEVICE, _ASR_DTYPE
    dev, dt = _pick_device_dtype()
    if _ASR is None or _ASR_DEVICE != dev:
        _ASR = pipeline(
            task="automatic-speech-recognition",
            model=MODEL_NAME,
            device=dev,
            torch_dtype=dt,
            return_timestamps="word",
            chunk_length_s=DEF_CHUNK_LEN_S,
            stride_length_s=DEF_STRIDE_LEN_S,
            ignore_warning=True,
        )
        _ASR_DEVICE, _ASR_DTYPE = dev, dt
        print(f"[ASR] Initialized on device={dev} dtype={dt}")
    return _ASR


# ================== Audio & utilities ==================
def _load_audio(path: str, sr: int = SR_TARGET):
    y, sr = librosa.load(path, sr=sr, mono=True)
    return y, sr


def _to_db(rms: np.ndarray):
    # RMS energy in dB relative to the clip's peak RMS.
    ref = np.maximum(np.max(rms), 1e-10)
    return 20.0 * np.log10(np.maximum(rms, 1e-10) / ref)


def _fmt_ts(sec: float) -> str:
    # Format seconds as an SRT timestamp (HH:MM:SS,mmm). Working in integer
    # milliseconds keeps the milliseconds field in 000-999 even when the
    # fractional part rounds up.
    if sec < 0:
        sec = 0.0
    ms_total = int(round(sec * 1000.0))
    h, rem = divmod(ms_total, 3_600_000)
    m, rem = divmod(rem, 60_000)
    s, ms = divmod(rem, 1000)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"


def _extract_word_stream(chunks) -> List[Tuple[str, float, float]]:
    # Normalize pipeline word chunks into (text, start, end) triples; the
    # pipeline may expose timestamps under slightly different keys/shapes.
    out = []
    if not chunks:
        return out
    for ch in chunks:
        txt = (ch.get("text") or "").strip()
        ts = ch.get("timestamp", ch.get("timestamps", None))
        if ts is None:
            s = ch.get("start", ch.get("time_start", None))
            e = ch.get("end", ch.get("time_end", None))
            if s is not None and e is not None and txt:
                s = float(s)
                e = float(e)
                if e < s:
                    e = s
                out.append((txt, s, e))
            continue
        if isinstance(ts, (list, tuple)) and len(ts) == 2 and txt:
            s = float(ts[0] or 0.0)
            e = float(ts[1] or 0.0)
            if e < s:
                e = s
            out.append((txt, s, e))
    return out


def _detect_silence_cuts(
    y: np.ndarray,
    sr: int,
    silence_min_len: float,
    db_drop: float,
    pctl_floor: float,
    frame_len_ms: int,
    hop_len_ms: int,
):
    # Propose a cut at the RMS-energy minimum of every silent stretch that
    # lasts at least silence_min_len seconds.
    frame_len = max(256, int(sr * frame_len_ms / 1000))
    hop_len = max(64, int(sr * hop_len_ms / 1000))
    rms = librosa.feature.rms(y=y, frame_length=frame_len, hop_length=hop_len, center=True)[0]
    rms_db = _to_db(rms)
    if _HAS_SCIPY:
        rms_db = median_filter(rms_db, size=5)

    # Silence threshold: db_drop below the peak, but never below the
    # pctl_floor-th percentile of the energy curve.
    max_db = float(np.max(rms_db))
    floor_db = float(np.percentile(rms_db, pctl_floor))
    thr = max(max_db - db_drop, floor_db)

    low = rms_db <= thr
    cut_times = []
    n = len(low)
    i = 0
    min_frames = max(1, int(silence_min_len * sr / hop_len))
    while i < n:
        if not low[i]:
            i += 1
            continue
        j = i + 1
        while j < n and low[j]:
            j += 1
        if (j - i) >= min_frames:
            local = rms_db[i:j]
            k = int(np.argmin(local))
            best = i + k
            cut_times.append(best * hop_len / sr)
        i = j

    total = float(len(y) / sr)
    cut_times = sorted(set(t for t in cut_times if 0.05 <= t <= total - 0.05))
    return cut_times, total
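
# A minimal offline sketch (hypothetical helper, not wired into the app):
# runs only the energy-based splitter on a local file so the silence
# parameters can be tuned without loading Whisper. The "sample.wav" path is
# an assumption for illustration.
def _demo_cut_tuning(path: str = "sample.wav"):
    y, sr = _load_audio(path)
    cuts, total = _detect_silence_cuts(
        y, sr,
        silence_min_len=DEF_SILENCE_MIN_LEN,
        db_drop=DEF_DB_DROP,
        pctl_floor=DEF_PCTL_FLOOR,
        frame_len_ms=DEF_FRAME_LEN_MS,
        hop_len_ms=DEF_HOP_LEN_MS,
    )
    print(f"{len(cuts)} cuts over {total:.2f}s: {[round(t, 2) for t in cuts]}")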
def _snap_to_word_bounds(cuts: List[float], words: List[Tuple[str, float, float]], max_dist=0.25):
    # Snap each cut to the nearest word start/end within max_dist seconds,
    # then drop snapped cuts that land closer than 0.12 s to each other.
    if not cuts or not words:
        return cuts
    bounds = sorted({b for _, s, e in words for b in (s, e)})
    snapped = []
    for t in cuts:
        idx = min(range(len(bounds)), key=lambda i: abs(bounds[i] - t))
        snapped.append(bounds[idx] if abs(bounds[idx] - t) <= max_dist else t)
    snapped = sorted(set(snapped))
    out = []
    for t in snapped:
        if not out or (t - out[-1]) >= 0.12:
            out.append(t)
    return out


def _segment(words: List[Tuple[str, float, float]], cuts: List[float], total: float, min_seg: float):
    if not words:
        return [(0.0, total, "")]
    bnds = [0.0] + [t for t in cuts if 0.0 < t < total] + [total]
    segs = []
    wi, W = 0, len(words)
    for i in range(len(bnds) - 1):
        L, R = bnds[i], bnds[i + 1]
        texts, starts, ends = [], [], []
        # Skip words that end before this window, then collect the words
        # overlapping [L, R).
        while wi < W and words[wi][2] <= L:
            wi += 1
        wj = wi
        while wj < W and words[wj][1] < R:
            txt, s, e = words[wj]
            if e > L and s < R:
                texts.append(txt)
                starts.append(s)
                ends.append(e)
            wj += 1
        if texts:
            st, en = max(min(starts), L), min(max(ends), R)
            segs.append([float(st), float(en), " ".join(texts).strip()])
        elif (R - L) >= max(0.25, min_seg * 0.5):
            segs.append([L, R, ""])

    def has_punc(t):
        return any(p in t for p in ",。!?,.!?;;::")

    # Merge segments shorter than min_seg into a neighbor, preferring the
    # neighbor whose text already contains punctuation, then the nearest one.
    i = 0
    while i < len(segs):
        st, en, tx = segs[i]
        if (en - st) < min_seg and len(segs) > 1:
            cand = []
            if i + 1 < len(segs):
                cand.append(i + 1)
            if i - 1 >= 0:
                cand.append(i - 1)
            cand.sort(key=lambda j: (not has_punc(segs[j][2]), abs(j - i)))
            t = cand[0]
            nst, nen = min(segs[t][0], st), max(segs[t][1], en)
            ntx = (" ".join([segs[t][2], tx]) if t < i else " ".join([tx, segs[t][2]])).strip()
            keep, drop = (t, i) if t < i else (i, t)
            segs[keep] = [nst, nen, ntx]
            del segs[drop]
            i = max(0, keep - 1)
            continue
        i += 1
    return [(st, en, tx.strip()) for st, en, tx in segs if (en - st) >= 0.12]


def _build_srt(segs: List[Tuple[float, float, str]]) -> str:
    lines = []
    for idx, (st, en, tx) in enumerate(segs, start=1):
        lines.append(str(idx))
        lines.append(f"{_fmt_ts(st)} --> {_fmt_ts(en)}")
        lines.append(tx)
        lines.append("")
    return "\n".join(lines).strip() + "\n"


# ================== Inference core (runs on the GPU) ==================
@spaces.GPU  # ZeroGPU entry point (invoked by the button click)
def transcribe_and_split(
    audio_path: str,
    silence_min_len: float = DEF_SILENCE_MIN_LEN,
    db_drop: float = DEF_DB_DROP,
    pctl_floor: float = DEF_PCTL_FLOOR,
    min_seg_dur: float = DEF_MIN_SEG_DUR,
    frame_len_ms: int = DEF_FRAME_LEN_MS,
    hop_len_ms: int = DEF_HOP_LEN_MS,
    cut_offset_sec: float = DEF_CUT_OFFSET_SEC,
):
    if not audio_path:
        raise gr.Error("Please upload or record audio first.")
    # Keep the try narrow (OSError only), so the gr.Error raised below is
    # not swallowed by the exception handler.
    try:
        size_mb = os.path.getsize(audio_path) / (1024 * 1024)
    except OSError:
        size_mb = 0.0
    if size_mb > FILE_LIMIT_MB:
        raise gr.Error(f"File too large: over {FILE_LIMIT_MB} MB.")

    asr = _get_asr()  # first created here, on the GPU
    result = asr(
        audio_path,
        return_timestamps="word",
        chunk_length_s=DEF_CHUNK_LEN_S,
        stride_length_s=DEF_STRIDE_LEN_S,
        batch_size=BATCH_SIZE,
    )
    text = (result.get("text") or "").strip()
    words = _extract_word_stream(result.get("chunks") or [])

    y, sr = _load_audio(audio_path, sr=SR_TARGET)
    cuts, total = _detect_silence_cuts(
        y, sr,
        silence_min_len=silence_min_len,
        db_drop=db_drop,
        pctl_floor=pctl_floor,
        frame_len_ms=frame_len_ms,
        hop_len_ms=hop_len_ms,
    )
    if abs(cut_offset_sec) > 1e-6:
        cuts = [max(0.0, min(total, t + cut_offset_sec)) for t in cuts]
    cuts = _snap_to_word_bounds(cuts, words, max_dist=0.25)

    segs = _segment(words, cuts, total, min_seg_dur)
    if not segs:
        segs = [(0.0, total, text)]

    srt = _build_srt(segs)
    tmpf = tempfile.NamedTemporaryFile(delete=False, suffix=".srt")
    tmpf.write(srt.encode("utf-8"))
    tmpf.flush()
    tmpf.close()
    return srt, tmpf.name
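
# A minimal sketch (hypothetical helper, not wired into the app) showing the
# SRT shape that _build_srt emits; the segment values below are made up.
def _demo_srt_format():
    segs = [(0.0, 3.24, "hello world"), (3.36, 7.81, "this is the second segment")]
    print(_build_srt(segs))
    # 1
    # 00:00:00,000 --> 00:00:03,240
    # hello world
    #
    # 2
    # 00:00:03,360 --> 00:00:07,810
    # this is the second segment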
# Expose a GPU entry point to the startup check (optional; it never needs to
# be called).
@spaces.GPU
def gpu_warmup():
    return "ok"


# ================== UI ==================
with gr.Blocks(title="Whisper Large V3 · Smart SRT Splitting", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "### 🎧 Whisper Large V3 · steadier SRT splitting\n"
        "- Word-level timestamps + cuts at energy minima + snapping to word boundaries\n"
        "- Overly short segments are merged automatically; the SRT includes index lines\n"
    )
    audio = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio (upload or record)")
    with gr.Accordion("Advanced parameters", open=False):
        with gr.Row():
            silence_min_len = gr.Slider(0.1, 1.0, value=DEF_SILENCE_MIN_LEN, step=0.05, label="Min silence duration (s)")
            db_drop = gr.Slider(10, 40, value=DEF_DB_DROP, step=1.0, label="Drop below peak (dB)")
            pctl_floor = gr.Slider(0, 50, value=DEF_PCTL_FLOOR, step=1.0, label="Energy floor percentile (%)")
        with gr.Row():
            min_seg_dur = gr.Slider(0.3, 3.0, value=DEF_MIN_SEG_DUR, step=0.05, label="Min segment duration (s)")
            frame_len_ms = gr.Slider(10, 50, value=DEF_FRAME_LEN_MS, step=1, label="Frame length (ms)")
            hop_len_ms = gr.Slider(5, 25, value=DEF_HOP_LEN_MS, step=1, label="Hop length (ms)")
        cut_offset_sec = gr.Slider(-0.20, 0.20, value=DEF_CUT_OFFSET_SEC, step=0.01, label="Global cut offset (s)")

    btn = gr.Button("Transcribe and generate SRT", variant="primary")
    srt_preview = gr.Textbox(lines=16, label="SRT preview", show_copy_button=True)
    srt_file = gr.File(label="Download SRT file", file_count="single")

    btn.click(
        fn=transcribe_and_split,  # note: binds the @spaces.GPU-decorated function
        inputs=[audio, silence_min_len, db_drop, pctl_floor,
                min_seg_dur, frame_len_ms, hop_len_ms, cut_offset_sec],
        outputs=[srt_preview, srt_file],
    )

if __name__ == "__main__":
    demo.launch()
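
# Optional deployment note: if the Space needs to buffer concurrent requests,
# Gradio's request queue can be enabled before launching, e.g.
#   demo.queue(max_size=20).launch()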