Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -2,32 +2,30 @@
|
|
2 |
import spaces
|
3 |
import torch
|
4 |
import gradio as gr
|
|
|
5 |
import tempfile
|
6 |
import os
|
7 |
|
8 |
-
#
|
9 |
-
|
10 |
-
try:
|
11 |
-
if USE_FASTER_WHISPER:
|
12 |
-
from faster_whisper import WhisperModel # type: ignore
|
13 |
-
_HAS_FW = True
|
14 |
-
except Exception:
|
15 |
-
_HAS_FW = False
|
16 |
-
|
17 |
-
# ====== 可调参数 ======
|
18 |
-
ASR_MODEL = "openai/whisper-large-v3" # Transformers 管线备用模型
|
19 |
BATCH_SIZE = 8
|
20 |
FILE_LIMIT_MB = 1000
|
21 |
-
MAX_SEG_DUR = 6.0
|
22 |
-
MAX_SEG_CHARS = 28
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
#
|
31 |
def _srt_timestamp(seconds: float | None) -> str:
|
32 |
if seconds is None or seconds < 0:
|
33 |
seconds = 0.0
|
@@ -37,25 +35,9 @@ def _srt_timestamp(seconds: float | None) -> str:
|
|
37 |
s, ms = divmod(ms, 1000)
|
38 |
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
|
39 |
|
40 |
-
#
|
41 |
-
_STRONG = "。!?.!?"
|
42 |
-
def _ensure_sentence_punct(s: str) -> str:
|
43 |
-
s = s.strip()
|
44 |
-
if not s:
|
45 |
-
return s
|
46 |
-
if s[-1] not in _STRONG:
|
47 |
-
s += "。"
|
48 |
-
return s
|
49 |
-
|
50 |
-
def _maybe_add_comma(s: str) -> str:
|
51 |
-
s = s.rstrip()
|
52 |
-
if s and s[-1] not in _STRONG + ",,、;;":
|
53 |
-
s += ","
|
54 |
-
return s
|
55 |
-
|
56 |
-
# ====== 文本切分(标点优先,其次按长度兜底)=====
|
57 |
def _split_text_units(txt: str, max_seg_chars: int) -> list[str]:
|
58 |
-
strong =
|
59 |
units, cur = [], []
|
60 |
for ch in txt:
|
61 |
cur.append(ch)
|
@@ -70,16 +52,22 @@ def _split_text_units(txt: str, max_seg_chars: int) -> list[str]:
|
|
70 |
if len(u) <= max_seg_chars:
|
71 |
refined.append(u)
|
72 |
else:
|
73 |
-
#
|
74 |
for i in range(0, len(u), max_seg_chars):
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
return [x for x in refined if x]
|
77 |
|
78 |
-
#
|
79 |
-
def
|
|
|
|
|
80 |
lines = []
|
81 |
-
prev_end = None
|
82 |
-
|
83 |
for ch in chunks or []:
|
84 |
text = (ch.get("text") or "").strip()
|
85 |
if not text:
|
@@ -91,13 +79,6 @@ def _chunks_to_srt_no_number(chunks, max_seg_dur=MAX_SEG_DUR, max_seg_chars=MAX_
|
|
91 |
else:
|
92 |
c_start, c_end = 0.0, 2.0
|
93 |
|
94 |
-
# 根据与上一段的“停顿”补逗号/句号(软提示,真正的标点在行尾处理)
|
95 |
-
if prev_end is not None:
|
96 |
-
gap = max(c_start - prev_end, 0.0)
|
97 |
-
# 我们不直接把标点写进上一行文本,而是作为划分参考
|
98 |
-
# 具体标点在每行最终输出前处理(见下方)
|
99 |
-
prev_end = c_end
|
100 |
-
|
101 |
units = _split_text_units(text, max_seg_chars)
|
102 |
if not units:
|
103 |
units = [text]
|
@@ -107,100 +88,32 @@ def _chunks_to_srt_no_number(chunks, max_seg_dur=MAX_SEG_DUR, max_seg_chars=MAX_
|
|
107 |
|
108 |
cur_t = c_start
|
109 |
for u in units:
|
110 |
-
alloc =
|
111 |
-
|
112 |
-
# 若超长,继续细分为不超过 max_seg_chars 的片,并均分时长
|
113 |
if alloc <= max_seg_dur:
|
114 |
pieces = [u]
|
115 |
per = alloc
|
116 |
else:
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
122 |
st = cur_t
|
123 |
en = st + per
|
124 |
-
|
125 |
-
frag = p.strip()
|
126 |
-
|
127 |
-
# 行尾自动补标点:更自然
|
128 |
-
# 规则:
|
129 |
-
# 1) 若该小片接近一句末尾(片段较长 或 达到 max_seg_chars)→ 句号
|
130 |
-
# 2) 否则可能是轻停顿 → 逗号
|
131 |
-
# 3) 若已有强标点,保持不动
|
132 |
-
if frag and frag[-1] not in _STRONG:
|
133 |
-
# 按时长和长度推断语气
|
134 |
-
if per >= PAUSE_LONG or len(frag) >= max_seg_chars * 0.9:
|
135 |
-
frag = _ensure_sentence_punct(frag)
|
136 |
-
elif per >= PAUSE_SHORT:
|
137 |
-
frag = _maybe_add_comma(frag)
|
138 |
-
# else: 很短的片段不强制加标点(避免过密)
|
139 |
-
|
140 |
lines.append(f"{_srt_timestamp(st)} --> {_srt_timestamp(en)}")
|
141 |
-
lines.append(
|
142 |
lines.append("")
|
143 |
cur_t = en
|
144 |
|
145 |
return "\n".join(lines).strip() + ("\n" if lines else "")
|
146 |
|
147 |
-
#
|
148 |
-
def _fw_transcribe_to_chunks(audio_path: str):
|
149 |
-
"""
|
150 |
-
使用 faster-whisper + vad_filter=True 返回“类 chunk”结构:
|
151 |
-
[{'text': '...', 'timestamp': [start, end]}...]
|
152 |
-
"""
|
153 |
-
# 模型建议设为 medium/large-v3;这里默认 large-v3
|
154 |
-
model = WhisperModel("large-v3", device=device, compute_type="auto")
|
155 |
-
# vad_filter=True:用 Silero VAD 过滤静音/噪音
|
156 |
-
segments, _info = model.transcribe(
|
157 |
-
audio_path,
|
158 |
-
vad_filter=True,
|
159 |
-
vad_parameters=dict(min_silence_duration_ms=int(PAUSE_SHORT * 1000)),
|
160 |
-
beam_size=5,
|
161 |
-
best_of=5,
|
162 |
-
)
|
163 |
-
chunks = []
|
164 |
-
for seg in segments:
|
165 |
-
chunks.append({
|
166 |
-
"text": seg.text.strip(),
|
167 |
-
"timestamp": [float(seg.start or 0.0), float(seg.end or 0.0)],
|
168 |
-
})
|
169 |
-
return chunks
|
170 |
-
|
171 |
-
# ====== Transformers 回退方案 ======
|
172 |
-
from transformers import pipeline as hf_pipeline
|
173 |
-
def _hf_transcribe_to_chunks(audio_path: str):
|
174 |
-
pipe = hf_pipeline(
|
175 |
-
task="automatic-speech-recognition",
|
176 |
-
model=ASR_MODEL,
|
177 |
-
chunk_length_s=30,
|
178 |
-
device=0 if torch.cuda.is_available() else "cpu",
|
179 |
-
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
180 |
-
return_timestamps=True,
|
181 |
-
)
|
182 |
-
result = pipe(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": "transcribe"})
|
183 |
-
# 期望返回形如:{"text": "…", "chunks": [{"text": "...", "timestamp": (s,e)}, ...]}
|
184 |
-
chunks = result.get("chunks") or []
|
185 |
-
# 如无 chunks,用整段兜底
|
186 |
-
if not chunks:
|
187 |
-
total_text = (result.get("text") or "").strip()
|
188 |
-
if total_text:
|
189 |
-
chunks = [{"text": total_text, "timestamp": (0.0, max(MAX_SEG_DUR, 2.0))}]
|
190 |
-
# 统一结构
|
191 |
-
norm = []
|
192 |
-
for ch in chunks:
|
193 |
-
ts = ch.get("timestamp") or ch.get("timestamps") or [0.0, 2.0]
|
194 |
-
if isinstance(ts, (list, tuple)) and len(ts) == 2:
|
195 |
-
s, e = ts
|
196 |
-
else:
|
197 |
-
s, e = 0.0, 2.0
|
198 |
-
norm.append({"text": (ch.get("text") or "").strip(), "timestamp": [float(s or 0.0), float(e or 0.0)]})
|
199 |
-
return norm
|
200 |
-
|
201 |
-
# ====== 主函数:上传音频 → SRT(无编号 + 简易标点 + 可选 VAD)======
|
202 |
@spaces.GPU
|
203 |
-
def transcribe_file_to_srt(audio_path: str, task: str
|
204 |
if not audio_path:
|
205 |
raise gr.Error("请先上传音频文件。")
|
206 |
try:
|
@@ -210,17 +123,14 @@ def transcribe_file_to_srt(audio_path: str, task: str, use_vad: bool):
|
|
210 |
except OSError:
|
211 |
pass
|
212 |
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
else:
|
217 |
-
chunks = _hf_transcribe_to_chunks(audio_path)
|
218 |
|
219 |
-
srt_str =
|
220 |
-
if not srt_str:
|
221 |
-
srt_str = "00:00:00,000 --> 00:00:02,000\n
|
222 |
|
223 |
-
# 输出文件
|
224 |
tmpdir = tempfile.mkdtemp(prefix="srt_")
|
225 |
base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
|
226 |
srt_path = os.path.join(tmpdir, f"{base}.srt")
|
@@ -229,23 +139,19 @@ def transcribe_file_to_srt(audio_path: str, task: str, use_vad: bool):
|
|
229 |
|
230 |
return srt_str, srt_path
|
231 |
|
232 |
-
#
|
233 |
demo = gr.Interface(
|
234 |
fn=transcribe_file_to_srt,
|
235 |
inputs=[
|
236 |
gr.Audio(sources="upload", type="filepath", label="Audio file"),
|
237 |
gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
|
238 |
-
gr.Checkbox(label="Use VAD + Auto Punctuation (simple)", value=True),
|
239 |
],
|
240 |
outputs=[
|
241 |
-
gr.Textbox(label="SRT Preview
|
242 |
gr.File(label="Download SRT"),
|
243 |
],
|
244 |
-
title="Upload Audio → SRT
|
245 |
-
description=(
|
246 |
-
"Optional VAD (via faster-whisper) for cleaner segments. "
|
247 |
-
"Adds simple punctuation by pause/length. Adjustable MAX_SEG_DUR / MAX_SEG_CHARS / PAUSE_*."
|
248 |
-
),
|
249 |
allow_flagging="never",
|
250 |
)
|
251 |
|
|
|
2 |
import spaces
|
3 |
import torch
|
4 |
import gradio as gr
|
5 |
+
from transformers import pipeline
|
6 |
import tempfile
|
7 |
import os
|
8 |
|
9 |
+
# ===== 参数 =====
|
10 |
+
MODEL_NAME = "openai/whisper-large-v3"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
BATCH_SIZE = 8
|
12 |
FILE_LIMIT_MB = 1000
|
13 |
+
MAX_SEG_DUR = 6.0
|
14 |
+
MAX_SEG_CHARS = 28
|
15 |
+
|
16 |
+
device = 0 if torch.cuda.is_available() else "cpu"
|
17 |
+
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
18 |
+
|
19 |
+
pipe = pipeline(
|
20 |
+
task="automatic-speech-recognition",
|
21 |
+
model=MODEL_NAME,
|
22 |
+
chunk_length_s=30,
|
23 |
+
device=device,
|
24 |
+
torch_dtype=dtype,
|
25 |
+
return_timestamps=True,
|
26 |
+
)
|
27 |
|
28 |
+
# ===== 时间戳格式化 =====
|
29 |
def _srt_timestamp(seconds: float | None) -> str:
|
30 |
if seconds is None or seconds < 0:
|
31 |
seconds = 0.0
|
|
|
35 |
s, ms = divmod(ms, 1000)
|
36 |
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
|
37 |
|
38 |
+
# ===== 文本切分 + 自动补标点 =====
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
def _split_text_units(txt: str, max_seg_chars: int) -> list[str]:
|
40 |
+
strong = "。!?.!?"
|
41 |
units, cur = [], []
|
42 |
for ch in txt:
|
43 |
cur.append(ch)
|
|
|
52 |
if len(u) <= max_seg_chars:
|
53 |
refined.append(u)
|
54 |
else:
|
55 |
+
# 长句继续切分,并自动补句号
|
56 |
for i in range(0, len(u), max_seg_chars):
|
57 |
+
piece = u[i:i+max_seg_chars].strip()
|
58 |
+
if piece and piece[-1] not in strong:
|
59 |
+
piece += "。"
|
60 |
+
refined.append(piece)
|
61 |
+
# 如果最后一段没有标点,补句号
|
62 |
+
if refined and refined[-1][-1] not in strong:
|
63 |
+
refined[-1] += "。"
|
64 |
return [x for x in refined if x]
|
65 |
|
66 |
+
# ===== chunks 转 SRT (无编号 + 自动标点) =====
|
67 |
+
def chunks_to_srt(chunks: list[dict],
|
68 |
+
max_seg_dur: float = MAX_SEG_DUR,
|
69 |
+
max_seg_chars: int = MAX_SEG_CHARS) -> str:
|
70 |
lines = []
|
|
|
|
|
71 |
for ch in chunks or []:
|
72 |
text = (ch.get("text") or "").strip()
|
73 |
if not text:
|
|
|
79 |
else:
|
80 |
c_start, c_end = 0.0, 2.0
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
units = _split_text_units(text, max_seg_chars)
|
83 |
if not units:
|
84 |
units = [text]
|
|
|
88 |
|
89 |
cur_t = c_start
|
90 |
for u in units:
|
91 |
+
alloc = total_dur * (len(u) / total_chars)
|
92 |
+
alloc = max(alloc, 0.2)
|
|
|
93 |
if alloc <= max_seg_dur:
|
94 |
pieces = [u]
|
95 |
per = alloc
|
96 |
else:
|
97 |
+
# 再次切分,均匀分时长
|
98 |
+
smalls = [u[i:i+max_seg_chars] for i in range(0, len(u), max_seg_chars)]
|
99 |
+
pieces = [s.strip() + ("。" if not s.endswith("。") else "") for s in smalls if s.strip()]
|
100 |
+
per = min(max_seg_dur, alloc / max(1, len(pieces)))
|
101 |
+
|
102 |
+
for p in pieces:
|
103 |
+
if p and p[-1] not in "。!?.!?":
|
104 |
+
p += "。"
|
105 |
st = cur_t
|
106 |
en = st + per
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
lines.append(f"{_srt_timestamp(st)} --> {_srt_timestamp(en)}")
|
108 |
+
lines.append(p.strip())
|
109 |
lines.append("")
|
110 |
cur_t = en
|
111 |
|
112 |
return "\n".join(lines).strip() + ("\n" if lines else "")
|
113 |
|
114 |
+
# ===== 上传音频 → SRT =====
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
@spaces.GPU
|
116 |
+
def transcribe_file_to_srt(audio_path: str, task: str):
|
117 |
if not audio_path:
|
118 |
raise gr.Error("请先上传音频文件。")
|
119 |
try:
|
|
|
123 |
except OSError:
|
124 |
pass
|
125 |
|
126 |
+
result = pipe(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task})
|
127 |
+
text = result.get("text", "") or ""
|
128 |
+
chunks = result.get("chunks") or []
|
|
|
|
|
129 |
|
130 |
+
srt_str = chunks_to_srt(chunks)
|
131 |
+
if not srt_str and text.strip():
|
132 |
+
srt_str = "00:00:00,000 --> 00:00:02,000\n" + (text.strip() + "。") + "\n"
|
133 |
|
|
|
134 |
tmpdir = tempfile.mkdtemp(prefix="srt_")
|
135 |
base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
|
136 |
srt_path = os.path.join(tmpdir, f"{base}.srt")
|
|
|
139 |
|
140 |
return srt_str, srt_path
|
141 |
|
142 |
+
# ===== 界面 =====
|
143 |
demo = gr.Interface(
|
144 |
fn=transcribe_file_to_srt,
|
145 |
inputs=[
|
146 |
gr.Audio(sources="upload", type="filepath", label="Audio file"),
|
147 |
gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
|
|
|
148 |
],
|
149 |
outputs=[
|
150 |
+
gr.Textbox(label="Transcript (SRT Preview)", lines=18),
|
151 |
gr.File(label="Download SRT"),
|
152 |
],
|
153 |
+
title="Upload Audio → SRT Subtitle",
|
154 |
+
description=f"Upload an audio file to generate time-stamped SRT subtitles (auto punctuation, no numbering). Model: {MODEL_NAME}",
|
|
|
|
|
|
|
155 |
allow_flagging="never",
|
156 |
)
|
157 |
|