import os, tempfile
from typing import List, Tuple

import numpy as np
import gradio as gr
import librosa
import torch

try:
    from scipy.ndimage import median_filter
    _HAS_SCIPY = True
except Exception:
    _HAS_SCIPY = False

from transformers import pipeline
import spaces  # Key: required for ZeroGPU

# ================== Default parameters ==================
MODEL_NAME = "openai/whisper-large-v3"
BATCH_SIZE = 8
FILE_LIMIT_MB = 1000
DEF_SILENCE_MIN_LEN = 0.45   # minimum silence length (s) that can host a cut
DEF_DB_DROP = 25.0           # silence threshold: this many dB below peak RMS
DEF_PCTL_FLOOR = 20.0        # percentile floor for the energy threshold
DEF_MIN_SEG_DUR = 1.00       # segments shorter than this (s) get merged
DEF_FRAME_LEN_MS = 25        # RMS frame length (ms)
DEF_HOP_LEN_MS = 10          # RMS hop length (ms)
DEF_CUT_OFFSET_SEC = 0.00    # global offset (s) applied to every cut
DEF_CHUNK_LEN_S = 20         # ASR chunk length (s)
DEF_STRIDE_LEN_S = 2         # ASR stride/overlap (s)
SR_TARGET = 16000            # Whisper's expected sample rate (Hz)

# ================== Global lazy loading ==================
_ASR = None
_ASR_DEVICE = None
_ASR_DTYPE = None

def _pick_device_dtype():
    # Prefer CUDA (fp16), then Apple MPS (fp16), falling back to CPU (fp32).
    if torch.cuda.is_available():
        return 0, torch.float16
    elif torch.backends.mps.is_available():
        return "mps", torch.float16
    else:
        return -1, torch.float32
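# For reference: this returns (0, torch.float16) on a CUDA machine and
# (-1, torch.float32) on a CPU-only one.
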
def _get_asr():
    """
    Under ZeroGPU, this must be called for the first time inside a function
    decorated with @spaces.GPU in order to get CUDA. It also works on CPU
    and regular GPUs.
    """
    global _ASR, _ASR_DEVICE, _ASR_DTYPE
    dev, dt = _pick_device_dtype()
    if _ASR is None or _ASR_DEVICE != dev:
        _ASR = pipeline(
            task="automatic-speech-recognition",
            model=MODEL_NAME,
            device=dev,
            torch_dtype=dt,
            return_timestamps="word",
            chunk_length_s=DEF_CHUNK_LEN_S,
            stride_length_s=DEF_STRIDE_LEN_S,
            ignore_warning=True,
        )
        _ASR_DEVICE, _ASR_DTYPE = dev, dt
        print(f"[ASR] Initialized on device={dev} dtype={dt}")
    return _ASR

# ================== Audio & utilities ==================
def _load_audio(path: str, sr: int = SR_TARGET):
    y, sr = librosa.load(path, sr=sr, mono=True)
    return y, sr

def _to_db(rms: np.ndarray):
    # Convert RMS to dB relative to the track's peak RMS (0 dB = loudest frame).
    ref = np.maximum(np.max(rms), 1e-10)
    return 20.0 * np.log10(np.maximum(rms, 1e-10) / ref)
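# Illustrative: _to_db(np.array([1.0, 0.1])) -> array([  0., -20.]),
# i.e. a frame at 10% of peak RMS sits 20 dB down.
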
def _fmt_ts(sec: float) -> str:
    # SRT timestamp "HH:MM:SS,mmm". Round to whole milliseconds first so the
    # fractional part can never roll over into an invalid ",1000".
    if sec < 0:
        sec = 0.0
    total_ms = int(round(sec * 1000.0))
    h, rem = divmod(total_ms, 3_600_000)
    m, rem = divmod(rem, 60_000)
    s, ms = divmod(rem, 1000)
    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
def _extract_word_stream(chunks) -> List[Tuple[str, float, float]]:
    # Normalize pipeline word chunks into (text, start, end) triples,
    # tolerating a few different timestamp key layouts.
    out = []
    if not chunks:
        return out
    for ch in chunks:
        txt = (ch.get("text") or "").strip()
        ts = ch.get("timestamp", ch.get("timestamps", None))
        if ts is None:
            s = ch.get("start", ch.get("time_start", None))
            e = ch.get("end", ch.get("time_end", None))
            if s is not None and e is not None and txt:
                s, e = float(s), float(e)
                if e < s:
                    e = s
                out.append((txt, s, e))
            continue
        if isinstance(ts, (list, tuple)) and len(ts) == 2 and txt:
            s = float(ts[0] or 0.0)
            e = float(ts[1] or 0.0)
            if e < s:
                e = s
            out.append((txt, s, e))
    return out
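# Expected shape of one entry in `chunks` (may vary across transformers
# versions, hence the fallbacks above):
#   {"text": " hello", "timestamp": (0.42, 0.77)}
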
def _detect_silence_cuts(
    y: np.ndarray,
    sr: int,
    silence_min_len: float,
    db_drop: float,
    pctl_floor: float,
    frame_len_ms: int,
    hop_len_ms: int,
):
    # Frame-level RMS energy in dB, lightly median-filtered when SciPy is available.
    frame_len = max(256, int(sr * frame_len_ms / 1000))
    hop_len = max(64, int(sr * hop_len_ms / 1000))
    rms = librosa.feature.rms(y=y, frame_length=frame_len, hop_length=hop_len, center=True)[0]
    rms_db = _to_db(rms)
    if _HAS_SCIPY:
        rms_db = median_filter(rms_db, size=5)
    # Silence threshold: db_drop below the peak, but never below the
    # pctl_floor percentile (guards against uniformly quiet recordings).
    max_db = float(np.max(rms_db))
    floor_db = float(np.percentile(rms_db, pctl_floor))
    thr = max(max_db - db_drop, floor_db)
    low = rms_db <= thr
    cut_times = []
    n = len(low)
    i = 0
    min_frames = max(1, int(silence_min_len * sr / hop_len))
    while i < n:
        if not low[i]:
            i += 1
            continue
        # Found a low-energy run [i, j); if long enough, cut at its quietest frame.
        j = i + 1
        while j < n and low[j]:
            j += 1
        if (j - i) >= min_frames:
            local = rms_db[i:j]
            k = int(np.argmin(local))
            best = i + k
            cut_times.append(best * hop_len / sr)
        i = j
    total = float(len(y) / sr)
    # Drop cuts within 50 ms of either edge of the file.
    cut_times = sorted(set(t for t in cut_times if 0.05 <= t <= total - 0.05))
    return cut_times, total
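# Minimal usage sketch with hypothetical values, assuming 16 kHz mono audio:
#   cuts, total = _detect_silence_cuts(y, 16000, 0.45, 25.0, 20.0, 25, 10)
#   # cuts -> candidate split times in seconds, e.g. [3.12, 7.88]
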
def _snap_to_word_bounds(cuts: List[float], words: List[Tuple[str, float, float]], max_dist=0.25):
    # Pull each cut to the nearest word start/end if one lies within max_dist seconds.
    if not cuts or not words:
        return cuts
    bounds = sorted({b for _, s, e in words for b in (s, e)})
    snapped = []
    for t in cuts:
        idx = min(range(len(bounds)), key=lambda i: abs(bounds[i] - t))
        snapped.append(bounds[idx] if abs(bounds[idx] - t) <= max_dist else t)
    snapped = sorted(set(snapped))
    # De-duplicate cuts closer than 120 ms.
    out = []
    for t in snapped:
        if not out or (t - out[-1]) >= 0.12:
            out.append(t)
    return out
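# Note: the nearest-bound search above is O(len(cuts) * len(bounds)). For very
# long audio, bisect.bisect_left over the sorted `bounds` list would reduce it
# to O(len(cuts) * log(len(bounds))); the linear scan is kept for clarity.
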
def _segment(words: List[Tuple[str, float, float]], cuts: List[float], total: float, min_seg: float):
    if not words:
        return [(0.0, total, "")]
    bnds = [0.0] + [t for t in cuts if 0.0 < t < total] + [total]
    segs = []
    wi, W = 0, len(words)
    for i in range(len(bnds) - 1):
        L, R = bnds[i], bnds[i + 1]
        texts, starts, ends = [], [], []
        # Skip words that end before this segment, then collect words overlapping [L, R).
        while wi < W and words[wi][2] <= L:
            wi += 1
        wj = wi
        while wj < W and words[wj][1] < R:
            txt, s, e = words[wj]
            if e > L and s < R:
                texts.append(txt)
                starts.append(s)
                ends.append(e)
            wj += 1
        if texts:
            st, en = max(min(starts), L), min(max(ends), R)
            segs.append([float(st), float(en), " ".join(texts).strip()])
        elif (R - L) >= max(0.25, min_seg * 0.5):
            segs.append([L, R, ""])

    def has_punc(t):
        return any(p in t for p in ",。!?,.!?;;::")

    # Merge segments shorter than min_seg into a neighbor, preferring a
    # neighbor whose text already contains punctuation.
    i = 0
    while i < len(segs):
        st, en, tx = segs[i]
        if (en - st) < min_seg and len(segs) > 1:
            cand = []
            if i + 1 < len(segs):
                cand.append(i + 1)
            if i - 1 >= 0:
                cand.append(i - 1)
            cand.sort(key=lambda j: (not has_punc(segs[j][2]), abs(j - i)))
            t = cand[0]
            nst, nen = min(segs[t][0], st), max(segs[t][1], en)
            ntx = (" ".join([segs[t][2], tx]) if t < i else " ".join([tx, segs[t][2]])).strip()
            keep, drop = (t, i) if t < i else (i, t)
            segs[keep] = [nst, nen, ntx]
            del segs[drop]
            i = max(0, keep - 1)
            continue
        i += 1
    return [(st, en, tx.strip()) for st, en, tx in segs if (en - st) >= 0.12]
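# Worked example (hypothetical two-word stream, traced by hand):
#   _segment([("hi", 0.0, 0.4), ("there", 0.5, 0.9)], cuts=[0.45], total=1.0, min_seg=0.3)
#   -> [(0.0, 0.4, "hi"), (0.5, 0.9, "there")]
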
def _build_srt(segs: List[Tuple[float, float, str]]) -> str:
    lines = []
    for idx, (st, en, tx) in enumerate(segs, start=1):
        lines.append(str(idx))
        lines.append(f"{_fmt_ts(st)} --> {_fmt_ts(en)}")
        lines.append(tx)
        lines.append("")
    return "\n".join(lines).strip() + "\n"
# ================== Inference core (executed on the GPU) ==================
# Key: the ZeroGPU entry point (invoked by the button click).
@spaces.GPU
def transcribe_and_split(
    audio_path: str,
    silence_min_len: float = DEF_SILENCE_MIN_LEN,
    db_drop: float = DEF_DB_DROP,
    pctl_floor: float = DEF_PCTL_FLOOR,
    min_seg_dur: float = DEF_MIN_SEG_DUR,
    frame_len_ms: int = DEF_FRAME_LEN_MS,
    hop_len_ms: int = DEF_HOP_LEN_MS,
    cut_offset_sec: float = DEF_CUT_OFFSET_SEC,
):
    if not audio_path:
        raise gr.Error("Please upload or record audio first.")
    try:
        if os.path.getsize(audio_path) / (1024 * 1024) > FILE_LIMIT_MB:
            raise gr.Error(f"File too large: exceeds {FILE_LIMIT_MB} MB.")
    except OSError:
        # Only swallow filesystem errors; let the size-limit gr.Error propagate.
        pass
    asr = _get_asr()  # First created on the GPU
    result = asr(
        audio_path,
        return_timestamps="word",
        chunk_length_s=DEF_CHUNK_LEN_S,
        stride_length_s=DEF_STRIDE_LEN_S,
        batch_size=BATCH_SIZE,
    )
    text = (result.get("text") or "").strip()
    words = _extract_word_stream(result.get("chunks") or [])
    y, sr = _load_audio(audio_path, sr=SR_TARGET)
    cuts, total = _detect_silence_cuts(
        y, sr,
        silence_min_len=silence_min_len,
        db_drop=db_drop,
        pctl_floor=pctl_floor,
        frame_len_ms=frame_len_ms,
        hop_len_ms=hop_len_ms,
    )
    if abs(cut_offset_sec) > 1e-6:
        cuts = [max(0.0, min(total, t + cut_offset_sec)) for t in cuts]
    cuts = _snap_to_word_bounds(cuts, words, max_dist=0.25)
    segs = _segment(words, cuts, total, min_seg_dur)
    if not segs:
        segs = [(0.0, total, text)]
    srt = _build_srt(segs)
    tmpf = tempfile.NamedTemporaryFile(delete=False, suffix=".srt")
    tmpf.write(srt.encode("utf-8"))
    tmpf.flush()
    tmpf.close()
    return srt, tmpf.name

# Expose a GPU entry point to the startup check (optional; it never needs to be called).
@spaces.GPU
def gpu_warmup():
    return "ok"

# ================== UI ==================
with gr.Blocks(title="Whisper Large V3 · Smart SRT Splitting", theme=gr.themes.Soft()) as demo:
    gr.Markdown("### 🎧 Whisper Large V3 · More robust SRT splitting\n"
                "- Word-level timestamps + cuts at energy minima + snapping to word boundaries\n"
                "- Overly short segments are merged automatically; the SRT includes index lines\n")
    audio = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio (upload or record)")
    with gr.Accordion("Advanced parameters", open=False):
        with gr.Row():
            silence_min_len = gr.Slider(0.1, 1.0, value=DEF_SILENCE_MIN_LEN, step=0.05, label="Min silence length (s)")
            db_drop = gr.Slider(10, 40, value=DEF_DB_DROP, step=1.0, label="Drop below peak (dB)")
            pctl_floor = gr.Slider(0, 50, value=DEF_PCTL_FLOOR, step=1.0, label="Energy floor (percentile)")
        with gr.Row():
            min_seg_dur = gr.Slider(0.3, 3.0, value=DEF_MIN_SEG_DUR, step=0.05, label="Min segment duration (s)")
            frame_len_ms = gr.Slider(10, 50, value=DEF_FRAME_LEN_MS, step=1, label="Frame length (ms)")
            hop_len_ms = gr.Slider(5, 25, value=DEF_HOP_LEN_MS, step=1, label="Hop length (ms)")
            cut_offset_sec = gr.Slider(-0.20, 0.20, value=DEF_CUT_OFFSET_SEC, step=0.01, label="Global cut offset (s)")
    btn = gr.Button("Transcribe and generate SRT", variant="primary")
    srt_preview = gr.Textbox(lines=16, label="SRT preview", show_copy_button=True)
    srt_file = gr.File(label="Download SRT file", file_count="single")
    btn.click(
        fn=transcribe_and_split,  # Note: this binds the @spaces.GPU-decorated function
        inputs=[audio, silence_min_len, db_drop, pctl_floor, min_seg_dur, frame_len_ms, hop_len_ms, cut_offset_sec],
        outputs=[srt_preview, srt_file],
    )

if __name__ == "__main__":
    demo.launch()