import os, tempfile
from typing import List, Tuple
import numpy as np
import gradio as gr
import librosa
import torch
try:
from scipy.ndimage import median_filter
_HAS_SCIPY = True
except Exception:
_HAS_SCIPY = False
from transformers import pipeline
import spaces  # required for ZeroGPU
# ================== Default parameters ==================
MODEL_NAME = "openai/whisper-large-v3"
BATCH_SIZE = 8
FILE_LIMIT_MB = 1000
DEF_SILENCE_MIN_LEN = 0.45
DEF_DB_DROP = 25.0
DEF_PCTL_FLOOR = 20.0
DEF_MIN_SEG_DUR = 1.00
DEF_FRAME_LEN_MS = 25
DEF_HOP_LEN_MS = 10
DEF_CUT_OFFSET_SEC = 0.00
DEF_CHUNK_LEN_S = 20
DEF_STRIDE_LEN_S = 2
SR_TARGET = 16000
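# Whisper consumes 16 kHz mono audio, so all inputs are resampled to SR_TARGET.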
# ================== Lazy-loaded global state ==================
_ASR = None
_ASR_DEVICE = None
_ASR_DTYPE = None
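# Device preference: CUDA (device index 0, fp16) -> Apple MPS (fp16) -> CPU (fp32).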
def _pick_device_dtype():
if torch.cuda.is_available():
return 0, torch.float16
elif torch.backends.mps.is_available():
return "mps", torch.float16
else:
return -1, torch.float32
def _get_asr():
"""
在 ZeroGPU 下必须在 @spaces.GPU 修饰的函数内首次调用,才能拿到 cuda。
CPU/常规 GPU 也兼容。
"""
global _ASR, _ASR_DEVICE, _ASR_DTYPE
dev, dt = _pick_device_dtype()
if _ASR is None or _ASR_DEVICE != dev:
_ASR = pipeline(
task="automatic-speech-recognition",
model=MODEL_NAME,
device=dev,
torch_dtype=dt,
return_timestamps="word",
chunk_length_s=DEF_CHUNK_LEN_S,
stride_length_s=DEF_STRIDE_LEN_S,
ignore_warning=True,
)
_ASR_DEVICE, _ASR_DTYPE = dev, dt
print(f"[ASR] Initialized on device={dev} dtype={dt}")
return _ASR
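# The pipeline is cached in _ASR and rebuilt only when the preferred device
# changes, e.g. when the first call inside a @spaces.GPU context sees CUDA.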
# ================== Audio & utilities ==================
def _load_audio(path: str, sr: int = SR_TARGET):
y, sr = librosa.load(path, sr=sr, mono=True)
return y, sr
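# Normalize an RMS envelope to dB relative to the clip's peak RMS; for
# example, _to_db(np.array([1.0, 0.1, 0.01])) gives [0., -20., -40.].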
def _to_db(rms: np.ndarray):
ref = np.maximum(np.max(rms), 1e-10)
return 20.0 * np.log10(np.maximum(rms, 1e-10) / ref)
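# SRT timestamps are HH:MM:SS,mmm with a comma before the milliseconds,
# e.g. _fmt_ts(3661.5) -> "01:01:01,500".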
def _fmt_ts(sec: float) -> str:
    if sec < 0:
        sec = 0.0
    # Round once at millisecond resolution so ms can never round up to 1000.
    total_ms = int(round(sec * 1000.0))
    h, rem = divmod(total_ms, 3_600_000)
    m, rem = divmod(rem, 60_000)
    s, ms = divmod(rem, 1_000)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
def _extract_word_stream(chunks) -> List[Tuple[str, float, float]]:
out = []
if not chunks:
return out
for ch in chunks:
txt = (ch.get("text") or "").strip()
ts = ch.get("timestamp", ch.get("timestamps", None))
if ts is None:
s = ch.get("start", ch.get("time_start", None))
e = ch.get("end", ch.get("time_end", None))
if s is not None and e is not None and txt:
s = float(s); e = float(e)
if e < s: e = s
out.append((txt, s, e))
continue
if isinstance(ts, (list, tuple)) and len(ts) == 2 and txt:
s = float(ts[0] or 0.0); e = float(ts[1] or 0.0)
if e < s: e = s
out.append((txt, s, e))
return out
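# Energy-based silence detection: frame the signal, convert the RMS envelope
# to dB relative to its peak, and mark runs of frames below
# max(max_db - db_drop, percentile floor) lasting at least silence_min_len as
# silences; the quietest frame of each run becomes a candidate cut time.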
def _detect_silence_cuts(
y: np.ndarray,
sr: int,
silence_min_len: float,
db_drop: float,
pctl_floor: float,
frame_len_ms: int,
hop_len_ms: int,
):
    frame_len = max(256, int(sr * frame_len_ms / 1000))
    hop_len = max(64, int(sr * hop_len_ms / 1000))
rms = librosa.feature.rms(y=y, frame_length=frame_len, hop_length=hop_len, center=True)[0]
rms_db = _to_db(rms)
if _HAS_SCIPY:
rms_db = median_filter(rms_db, size=5)
max_db = float(np.max(rms_db))
floor_db = float(np.percentile(rms_db, pctl_floor))
thr = max(max_db - db_drop, floor_db)
low = rms_db <= thr
cut_times = []
n = len(low)
i = 0
min_frames = max(1, int(silence_min_len * sr / hop_len))
while i < n:
if not low[i]:
i += 1; continue
j = i + 1
while j < n and low[j]:
j += 1
if (j - i) >= min_frames:
local = rms_db[i:j]
k = int(np.argmin(local))
best = i + k
cut_times.append(best * hop_len / sr)
i = j
total = float(len(y) / sr)
cut_times = sorted(set(t for t in cut_times if 0.05 <= t <= total - 0.05))
return cut_times, total
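# Pull each cut to the nearest word start/end if one lies within max_dist
# seconds, then drop snapped cuts that land within 120 ms of each other.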
def _snap_to_word_bounds(cuts: List[float], words: List[Tuple[str, float, float]], max_dist=0.25):
if not cuts or not words: return cuts
bounds = sorted({b for _, s, e in words for b in (s, e)})
snapped = []
for t in cuts:
idx = min(range(len(bounds)), key=lambda i: abs(bounds[i]-t))
snapped.append(bounds[idx] if abs(bounds[idx]-t) <= max_dist else t)
snapped = sorted(set(snapped))
out = []
for t in snapped:
if not out or (t - out[-1]) >= 0.12:
out.append(t)
return out
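# Assemble words into segments between consecutive cut boundaries, then merge
# any segment shorter than min_seg into a neighbor, preferring a neighbor
# whose text already contains punctuation and breaking ties by proximity.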
def _segment(words: List[Tuple[str,float,float]], cuts: List[float], total: float, min_seg: float):
if not words:
return [(0.0, total, "")]
bnds = [0.0] + [t for t in cuts if 0.0 < t < total] + [total]
segs = []
wi, W = 0, len(words)
for i in range(len(bnds)-1):
L, R = bnds[i], bnds[i+1]
texts, starts, ends = [], [], []
while wi < W and words[wi][2] <= L:
wi += 1
wj = wi
while wj < W and words[wj][1] < R:
txt, s, e = words[wj]
if e > L and s < R:
texts.append(txt); starts.append(s); ends.append(e)
wj += 1
if texts:
st, en = max(min(starts), L), min(max(ends), R)
segs.append([float(st), float(en), " ".join(texts).strip()])
elif (R - L) >= max(0.25, min_seg * 0.5):
segs.append([L, R, ""])
def has_punc(t): return any(p in t for p in ",。!?,.!?;;::")
i = 0
while i < len(segs):
st, en, tx = segs[i]
if (en - st) < min_seg and len(segs) > 1:
cand = []
if i + 1 < len(segs): cand.append(i + 1)
if i - 1 >= 0: cand.append(i - 1)
cand.sort(key=lambda j: (not has_punc(segs[j][2]), abs(j - i)))
t = cand[0]
nst, nen = min(segs[t][0], st), max(segs[t][1], en)
ntx = (" ".join([segs[t][2], tx]) if t < i else " ".join([tx, segs[t][2]])).strip()
keep, drop = (t, i) if t < i else (i, t)
segs[keep] = [nst, nen, ntx]
del segs[drop]
i = max(0, keep - 1); continue
i += 1
return [(st, en, tx.strip()) for st, en, tx in segs if (en - st) >= 0.12]
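# Serialize segments to SRT. Example for [(0.0, 2.5, "hello world")]:
#   1
#   00:00:00,000 --> 00:00:02,500
#   hello world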
def _build_srt(segs: List[Tuple[float,float,str]]) -> str:
lines = []
for idx, (st, en, tx) in enumerate(segs, start=1):
lines.append(str(idx))
lines.append(f"{_fmt_ts(st)} --> {_fmt_ts(en)}")
lines.append(tx)
lines.append("")
return "\n".join(lines).strip() + "\n"
# ================== Inference core (runs on the GPU) ==================
@spaces.GPU  # ZeroGPU entry point (invoked by the button click)
def transcribe_and_split(
audio_path: str,
silence_min_len: float = DEF_SILENCE_MIN_LEN,
db_drop: float = DEF_DB_DROP,
pctl_floor: float = DEF_PCTL_FLOOR,
min_seg_dur: float = DEF_MIN_SEG_DUR,
frame_len_ms: int = DEF_FRAME_LEN_MS,
hop_len_ms: int = DEF_HOP_LEN_MS,
cut_offset_sec: float = DEF_CUT_OFFSET_SEC,
):
    if not audio_path:
        raise gr.Error("Please upload or record an audio file first.")
    # gr.Error subclasses Exception, so the size check must raise outside the
    # try block; otherwise the guard would swallow its own error.
    try:
        too_big = os.path.getsize(audio_path) / (1024 * 1024) > FILE_LIMIT_MB
    except OSError:
        too_big = False
    if too_big:
        raise gr.Error(f"File is too large; the limit is {FILE_LIMIT_MB} MB.")
    asr = _get_asr()  # first created here, on the GPU
result = asr(
audio_path,
return_timestamps="word",
chunk_length_s=DEF_CHUNK_LEN_S,
stride_length_s=DEF_STRIDE_LEN_S,
batch_size=BATCH_SIZE,
)
text = (result.get("text") or "").strip()
words = _extract_word_stream(result.get("chunks") or [])
y, sr = _load_audio(audio_path, sr=SR_TARGET)
cuts, total = _detect_silence_cuts(
y, sr,
silence_min_len=silence_min_len,
db_drop=db_drop,
pctl_floor=pctl_floor,
frame_len_ms=frame_len_ms,
hop_len_ms=hop_len_ms,
)
if abs(cut_offset_sec) > 1e-6:
cuts = [max(0.0, min(total, t + cut_offset_sec)) for t in cuts]
cuts = _snap_to_word_bounds(cuts, words, max_dist=0.25)
segs = _segment(words, cuts, total, min_seg_dur)
if not segs:
segs = [(0.0, total, text)]
srt = _build_srt(segs)
tmpf = tempfile.NamedTemporaryFile(delete=False, suffix=".srt")
tmpf.write(srt.encode("utf-8")); tmpf.flush(); tmpf.close()
return srt, tmpf.name
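# Hypothetical local smoke test (not wired into the UI; "sample.wav" is a
# placeholder path):
#   srt_text, srt_path = transcribe_and_split("sample.wav")
#   print(srt_text.splitlines()[:4])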
# Expose a GPU entry point to the startup check (optional; it never has to be called)
@spaces.GPU
def gpu_warmup():
return "ok"
# ================== UI ==================
with gr.Blocks(title="Whisper Large V3 · Smart SRT Splitting", theme=gr.themes.Soft()) as demo:
    gr.Markdown("### 🎧 Whisper Large V3 · More Robust SRT Splitting\n"
                "- Word-level timestamps + cuts at energy minima + snapping to word boundaries\n"
                "- Segments that are too short are merged automatically; the SRT includes index lines\n")
    audio = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio (upload or record)")
with gr.Accordion("高级参数", open=False):
with gr.Row():
silence_min_len = gr.Slider(0.1, 1.0, value=DEF_SILENCE_MIN_LEN, step=0.05, label="静音最短时长 (s)")
db_drop = gr.Slider(10, 40, value=DEF_DB_DROP, step=1.0, label="相对峰值下落 (dB)")
pctl_floor = gr.Slider(0, 50, value=DEF_PCTL_FLOOR, step=1.0, label="能量分位下限 (dB)")
with gr.Row():
min_seg_dur = gr.Slider(0.3, 3.0, value=DEF_MIN_SEG_DUR, step=0.05, label="最短片段时长 (s)")
frame_len_ms = gr.Slider(10, 50, value=DEF_FRAME_LEN_MS, step=1, label="帧长 (ms)")
hop_len_ms = gr.Slider(5, 25, value=DEF_HOP_LEN_MS, step=1, label="帧移 (ms)")
cut_offset_sec = gr.Slider(-0.20, 0.20, value=DEF_CUT_OFFSET_SEC, step=0.01, label="切分整体偏移 (s)")
btn = gr.Button("开始识别并生成 SRT", variant="primary")
srt_preview = gr.Textbox(lines=16, label="SRT 预览", show_copy_button=True)
srt_file = gr.File(label="下载 SRT 文件", file_count="single")
btn.click(
        fn=transcribe_and_split,  # note: this binds the @spaces.GPU-decorated function
inputs=[audio, silence_min_len, db_drop, pctl_floor, min_seg_dur, frame_len_ms, hop_len_ms, cut_offset_sec],
outputs=[srt_preview, srt_file],
)
if __name__ == "__main__":
demo.launch()