Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -6,30 +6,32 @@ from transformers import pipeline
|
|
6 |
import tempfile
|
7 |
import os
|
8 |
|
9 |
-
# ==================
|
10 |
MODEL_NAME = "openai/whisper-large-v3"
|
11 |
BATCH_SIZE = 8
|
12 |
FILE_LIMIT_MB = 1000
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
MIN_PIECE_DUR = 0.30 # 每句最小时长(秒),避免闪烁
|
18 |
-
STRONG_PUNCT = "。!?.!?"
|
19 |
|
|
|
20 |
device = 0 if torch.cuda.is_available() else "cpu"
|
21 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
22 |
|
|
|
23 |
asr = pipeline(
|
24 |
task="automatic-speech-recognition",
|
25 |
model=MODEL_NAME,
|
26 |
chunk_length_s=30,
|
27 |
device=device,
|
28 |
torch_dtype=dtype,
|
29 |
-
return_timestamps=True, #
|
30 |
)
|
31 |
|
|
|
32 |
def _ts(t: float | None) -> str:
|
|
|
33 |
if t is None or t < 0:
|
34 |
t = 0.0
|
35 |
ms = int(float(t) * 1000 + 0.5)
|
@@ -39,17 +41,20 @@ def _ts(t: float | None) -> str:
|
|
39 |
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
|
40 |
|
41 |
def _norm_chunks(chunks: list[dict]) -> list[dict]:
|
42 |
-
"""
|
|
|
|
|
|
|
43 |
out = []
|
44 |
for ch in chunks or []:
|
45 |
text = (ch.get("text") or "").strip()
|
46 |
-
ts = ch.get("timestamp") or ch.get("timestamps") or [0.0,
|
47 |
if not text:
|
48 |
continue
|
49 |
if isinstance(ts, (list, tuple)) and len(ts) == 2:
|
50 |
s, e = float(ts[0] or 0.0), float(ts[1] or 0.0)
|
51 |
else:
|
52 |
-
s, e = 0.0,
|
53 |
if e < s:
|
54 |
e = s
|
55 |
out.append({"text": text, "start": s, "end": e})
|
@@ -57,7 +62,7 @@ def _norm_chunks(chunks: list[dict]) -> list[dict]:
|
|
57 |
|
58 |
def _char_timeline(chunk: dict) -> list[tuple[str, float, float]]:
|
59 |
"""
|
60 |
-
|
61 |
返回 [(char, char_start, char_end), ...]
|
62 |
"""
|
63 |
text = chunk["text"]
|
@@ -73,147 +78,65 @@ def _char_timeline(chunk: dict) -> list[tuple[str, float, float]]:
|
|
73 |
cur = nxt
|
74 |
return timeline
|
75 |
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
"""
|
81 |
-
|
82 |
-
-
|
83 |
-
-
|
84 |
-
-
|
85 |
-
- 超长句再按时长 <= max_seg_dur 均匀切
|
86 |
"""
|
87 |
segments = []
|
88 |
-
|
89 |
buf_chars: list[str] = []
|
90 |
-
buf_start = None
|
91 |
-
|
92 |
-
|
93 |
|
94 |
-
def
|
95 |
-
nonlocal buf_chars, buf_start,
|
96 |
if not buf_chars:
|
97 |
return
|
98 |
text = "".join(buf_chars).strip()
|
99 |
-
if not text:
|
100 |
-
buf_chars = []
|
101 |
-
buf_start = None
|
102 |
-
since_last_comma = 0
|
103 |
-
return
|
104 |
-
# 保证句末有强标点
|
105 |
-
if force_punct and text[-1] not in STRONG_PUNCT:
|
106 |
-
text += "。"
|
107 |
-
elif text[-1] not in STRONG_PUNCT:
|
108 |
-
text += "。"
|
109 |
st = buf_start if buf_start is not None else 0.0
|
110 |
-
en =
|
111 |
-
|
112 |
-
|
113 |
-
en = st + MIN_PIECE_DUR
|
114 |
segments.append((st, en, text))
|
115 |
-
buf_chars = []
|
116 |
-
buf_start = None
|
117 |
-
since_last_comma = 0
|
118 |
-
|
119 |
-
def try_hard_wrap_long(st: float, en: float, text: str):
|
120 |
-
"""
|
121 |
-
如果单句太长(> max_seg_dur),按时长把文本均匀切成多块,每块 <= max_seg_dur,句末补句号。
|
122 |
-
使用 [st,en] 线性映射。
|
123 |
-
"""
|
124 |
-
out = []
|
125 |
-
dur = max(en - st, 0.0)
|
126 |
-
if dur <= max_seg_dur:
|
127 |
-
return [(st, en, text if text[-1] in STRONG_PUNCT else (text + "。"))]
|
128 |
|
129 |
-
|
130 |
-
|
131 |
-
# 按字符再均匀切
|
132 |
-
L = len(text)
|
133 |
-
piece_len = max(L // k, 1)
|
134 |
-
pos = 0
|
135 |
-
for i in range(k):
|
136 |
-
sub = text[pos:pos + piece_len]
|
137 |
-
if not sub:
|
138 |
-
continue
|
139 |
-
sub_st = st + (i / k) * dur
|
140 |
-
sub_en = st + ((i + 1) / k) * dur
|
141 |
-
if sub[-1] not in STRONG_PUNCT:
|
142 |
-
sub += "。"
|
143 |
-
if sub_en - sub_st < MIN_PIECE_DUR:
|
144 |
-
sub_en = sub_st + MIN_PIECE_DUR
|
145 |
-
out.append((sub_st, sub_en, sub))
|
146 |
-
pos += piece_len
|
147 |
-
# 余数
|
148 |
-
if pos < L:
|
149 |
-
sub = text[pos:]
|
150 |
-
sub_st = st + (len("".join([t for _,_,t in out])) / max(L,1)) * dur
|
151 |
-
sub_en = en
|
152 |
-
if sub and sub[-1] not in STRONG_PUNCT:
|
153 |
-
sub += "。"
|
154 |
-
if sub_en - sub_st < MIN_PIECE_DUR:
|
155 |
-
sub_en = sub_st + MIN_PIECE_DUR
|
156 |
-
out.append((sub_st, sub_en, sub))
|
157 |
-
return out
|
158 |
-
|
159 |
-
# 遍历逐字时间线
|
160 |
-
for ch, ch_st, ch_en in char_stream:
|
161 |
if buf_start is None:
|
162 |
-
buf_start =
|
163 |
-
buf_chars.append(ch)
|
164 |
-
last_char_end = ch_en
|
165 |
-
since_last_comma += 1
|
166 |
|
167 |
-
#
|
168 |
-
if
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
# 逗号式短停顿(可选)
|
173 |
-
if comma_every and since_last_comma >= comma_every:
|
174 |
-
# 只在当前累积达到目标一半以上时才加逗号,避免太碎
|
175 |
-
if len(buf_chars) >= max(6, target_chars // 2):
|
176 |
-
if buf_chars and buf_chars[-1] not in ",,、;;" and buf_chars[-1] not in STRONG_PUNCT:
|
177 |
-
buf_chars.append(",")
|
178 |
-
flush_sentence(force_punct=False)
|
179 |
-
continue
|
180 |
-
else:
|
181 |
-
since_last_comma = 0 # 重置计数,继续攒
|
182 |
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
|
187 |
# 收尾
|
188 |
-
|
189 |
-
|
190 |
-
# 二次处理:把任何超时长的句子再按时长切块(<= MAX_SEG_DUR)
|
191 |
-
final_segments = []
|
192 |
-
for st, en, tx in segments:
|
193 |
-
if en - st > max_seg_dur:
|
194 |
-
final_segments.extend(try_hard_wrap_long(st, en, tx))
|
195 |
-
else:
|
196 |
-
final_segments.append((st, en, tx if tx[-1] in STRONG_PUNCT else (tx + "。")))
|
197 |
-
|
198 |
-
return final_segments
|
199 |
|
200 |
def chunks_to_srt_no_number(chunks: list[dict]) -> str:
|
201 |
"""
|
202 |
-
|
203 |
"""
|
204 |
norm = _norm_chunks(chunks)
|
205 |
-
|
|
|
206 |
char_stream = []
|
207 |
for ch in norm:
|
208 |
char_stream.extend(_char_timeline(ch))
|
209 |
|
210 |
-
#
|
211 |
-
segs =
|
212 |
-
char_stream,
|
213 |
-
target_chars=TARGET_SENT_CHARS,
|
214 |
-
comma_every=COMMA_EVERY,
|
215 |
-
max_seg_dur=MAX_SEG_DUR,
|
216 |
-
)
|
217 |
|
218 |
# 输出(无编号)
|
219 |
lines = []
|
@@ -223,7 +146,7 @@ def chunks_to_srt_no_number(chunks: list[dict]) -> str:
|
|
223 |
lines.append("")
|
224 |
return "\n".join(lines).strip() + ("\n" if lines else "")
|
225 |
|
226 |
-
# ================== 推理与UI ==================
|
227 |
@spaces.GPU
|
228 |
def transcribe_file_to_srt(audio_path: str, task: str):
|
229 |
if not audio_path:
|
@@ -233,16 +156,22 @@ def transcribe_file_to_srt(audio_path: str, task: str):
|
|
233 |
if size_mb > FILE_LIMIT_MB:
|
234 |
raise gr.Error(f"文件过大:{size_mb:.1f} MB,超过限制 {FILE_LIMIT_MB} MB。")
|
235 |
except OSError:
|
|
|
236 |
pass
|
237 |
|
|
|
238 |
result = asr(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task})
|
239 |
text = (result.get("text") or "").strip()
|
240 |
chunks = result.get("chunks") or []
|
241 |
|
|
|
242 |
srt_str = chunks_to_srt_no_number(chunks)
|
|
|
|
|
243 |
if not srt_str and text:
|
244 |
-
srt_str = "00:00:00,000 --> 00:00:02,000\n" +
|
245 |
|
|
|
246 |
tmpdir = tempfile.mkdtemp(prefix="srt_")
|
247 |
base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
|
248 |
srt_path = os.path.join(tmpdir, f"{base}.srt")
|
@@ -258,12 +187,12 @@ demo = gr.Interface(
|
|
258 |
gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
|
259 |
],
|
260 |
outputs=[
|
261 |
-
gr.Textbox(label="Transcript (SRT Preview —
|
262 |
gr.File(label="Download SRT"),
|
263 |
],
|
264 |
-
title="Upload Audio → SRT (
|
265 |
-
description=f"
|
266 |
allow_flagging="never",
|
267 |
)
|
268 |
|
269 |
-
demo.queue().launch()
|
|
|
6 |
import tempfile
|
7 |
import os
|
8 |
|
9 |
+
# ================== 配置与可调参数 ==================
|
10 |
MODEL_NAME = "openai/whisper-large-v3"
|
11 |
BATCH_SIZE = 8
|
12 |
FILE_LIMIT_MB = 1000
|
13 |
|
14 |
+
# —— 静音切分策略(本次需求)——
|
15 |
+
SILENCE_GAP = 0.2 # split when a character starts >= 0.2 s after the previous character ended
|
16 |
+
MIN_SEG_DUR = 0.30 # 每段最小时长,避免字幕闪烁
|
|
|
|
|
17 |
|
18 |
+
# 设备/精度
|
19 |
device = 0 if torch.cuda.is_available() else "cpu"
|
20 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
21 |
|
22 |
+
# ASR 推理器:开启 chunk 级时间戳
|
23 |
asr = pipeline(
|
24 |
task="automatic-speech-recognition",
|
25 |
model=MODEL_NAME,
|
26 |
chunk_length_s=30,
|
27 |
device=device,
|
28 |
torch_dtype=dtype,
|
29 |
+
return_timestamps=True, # 返回 chunk 级别时间戳(start, end)
|
30 |
)
|
31 |
|
32 |
+
# ================== 工具函数 ==================
|
33 |
def _ts(t: float | None) -> str:
|
34 |
+
"""将秒数转为 SRT 时间戳 00:00:00,000"""
|
35 |
if t is None or t < 0:
|
36 |
t = 0.0
|
37 |
ms = int(float(t) * 1000 + 0.5)
|
|
|
41 |
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
|
42 |
|
43 |
def _norm_chunks(chunks: list[dict]) -> list[dict]:
|
44 |
+
"""
|
45 |
+
规范化 chunks: [{'text': str, 'start': float, 'end': float}]
|
46 |
+
兼容不同字段名:'timestamp' 或 'timestamps'
|
47 |
+
"""
|
48 |
out = []
|
49 |
for ch in chunks or []:
|
50 |
text = (ch.get("text") or "").strip()
|
51 |
+
ts = ch.get("timestamp") or ch.get("timestamps") or [0.0, 0.0]
|
52 |
if not text:
|
53 |
continue
|
54 |
if isinstance(ts, (list, tuple)) and len(ts) == 2:
|
55 |
s, e = float(ts[0] or 0.0), float(ts[1] or 0.0)
|
56 |
else:
|
57 |
+
s, e = 0.0, 0.0
|
58 |
if e < s:
|
59 |
e = s
|
60 |
out.append({"text": text, "start": s, "end": e})
|
|
|
62 |
|
63 |
def _char_timeline(chunk: dict) -> list[tuple[str, float, float]]:
|
64 |
"""
|
65 |
+
对单个 chunk 逐字符线性插值时间:
|
66 |
返回 [(char, char_start, char_end), ...]
|
67 |
"""
|
68 |
text = chunk["text"]
|
|
|
78 |
cur = nxt
|
79 |
return timeline
|
80 |
|
81 |
+
# ================== 按静音间隔切分(不依赖标点) ==================
|
82 |
+
def _segment_by_silence(char_stream: list[tuple[str, float, float]],
                        silence_gap: float = SILENCE_GAP,
                        min_seg_dur: float = MIN_SEG_DUR) -> list[tuple[float, float, str]]:
    """
    Split a per-character timeline into subtitle segments at silent gaps.

    A new segment begins whenever a character starts at least ``silence_gap``
    seconds after the previous character ended. Characters are concatenated
    as-is (no punctuation handling); only leading/trailing whitespace of each
    segment's text is stripped.

    Args:
        char_stream: ``(char, start, end)`` tuples, in the order they should
            appear in the output.
        silence_gap: Minimum pause (current start minus previous end) that
            triggers a split.
        min_seg_dur: Minimum duration of an emitted segment; a shorter segment
            has its end moved to ``start + min_seg_dur``.

    Returns:
        A list of ``(start, end, text)`` tuples, one per segment. Buffers whose
        text is empty after stripping are dropped instead of being emitted as
        blank cues.
    """
    segments: list[tuple[float, float, str]] = []
    buf_chars: list[str] = []
    buf_start: float | None = None
    # End time of the most recently consumed character. It serves both as the
    # reference point for gap detection and as the end of the pending segment.
    prev_end: float | None = None

    def flush() -> None:
        """Emit the buffered characters as one segment and reset the buffer."""
        nonlocal buf_chars, buf_start
        if not buf_chars:
            return
        text = "".join(buf_chars).strip()
        st = buf_start if buf_start is not None else 0.0
        en = prev_end if prev_end is not None else st
        buf_chars, buf_start = [], None
        if not text:
            # Whitespace-only buffer: emitting it would create an empty cue.
            return
        # NOTE(review): padding to min_seg_dur can push `en` past the start of
        # the following segment when the pause is shorter than the padding;
        # left unchanged here to keep the existing timing contract.
        if en - st < min_seg_dur:
            en = st + min_seg_dur
        segments.append((st, en, text))

    for ch, st, en in char_stream:
        # Close the pending segment first when the pause before this character
        # is long enough; the current character then opens the next segment.
        if prev_end is not None and (st - prev_end) >= silence_gap:
            flush()
        if buf_start is None:
            buf_start = st
        buf_chars.append(ch)
        prev_end = en

    # Emit whatever is still buffered after the last character.
    flush()
    return segments
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
def chunks_to_srt_no_number(chunks: list[dict]) -> str:
|
128 |
"""
|
129 |
+
流程:规范化 chunks → 全局逐字符时间线 → 按静音切分 → 输出无编号 SRT
|
130 |
"""
|
131 |
norm = _norm_chunks(chunks)
|
132 |
+
|
133 |
+
# 拼接全局字符流
|
134 |
char_stream = []
|
135 |
for ch in norm:
|
136 |
char_stream.extend(_char_timeline(ch))
|
137 |
|
138 |
+
# 静音切分
|
139 |
+
segs = _segment_by_silence(char_stream)
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
# 输出(无编号)
|
142 |
lines = []
|
|
|
146 |
lines.append("")
|
147 |
return "\n".join(lines).strip() + ("\n" if lines else "")
|
148 |
|
149 |
+
# ================== 推理与 UI ==================
|
150 |
@spaces.GPU
|
151 |
def transcribe_file_to_srt(audio_path: str, task: str):
|
152 |
if not audio_path:
|
|
|
156 |
if size_mb > FILE_LIMIT_MB:
|
157 |
raise gr.Error(f"文件过大:{size_mb:.1f} MB,超过限制 {FILE_LIMIT_MB} MB。")
|
158 |
except OSError:
|
159 |
+
# 某些远端路径可能拿不到大小,忽略即可
|
160 |
pass
|
161 |
|
162 |
+
# 运行 ASR
|
163 |
result = asr(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task})
|
164 |
text = (result.get("text") or "").strip()
|
165 |
chunks = result.get("chunks") or []
|
166 |
|
167 |
+
# 基于静音切分生成 SRT(无编号)
|
168 |
srt_str = chunks_to_srt_no_number(chunks)
|
169 |
+
|
170 |
+
# 兜底:若无 chunks,则给出整段兜底字幕
|
171 |
if not srt_str and text:
|
172 |
+
srt_str = "00:00:00,000 --> 00:00:02,000\n" + text + "\n"
|
173 |
|
174 |
+
# 写入临时文件,供下载
|
175 |
tmpdir = tempfile.mkdtemp(prefix="srt_")
|
176 |
base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
|
177 |
srt_path = os.path.join(tmpdir, f"{base}.srt")
|
|
|
187 |
gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
|
188 |
],
|
189 |
outputs=[
|
190 |
+
gr.Textbox(label="Transcript (SRT Preview — silence-based, no numbering)", lines=18),
|
191 |
gr.File(label="Download SRT"),
|
192 |
],
|
193 |
+
title="Upload Audio → SRT (Silence-Based Segmentation, No Numbering)",
|
194 |
+
description=f"切分规则:相邻字符静音间隔 ≥ {SILENCE_GAP}s 则切段;不依赖标点。模型:{MODEL_NAME}",
|
195 |
allow_flagging="never",
|
196 |
)
|
197 |
|
198 |
+
demo.queue().launch()
|