Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -6,112 +6,224 @@ from transformers import pipeline
|
|
6 |
import tempfile
|
7 |
import os
|
8 |
|
9 |
-
#
|
10 |
MODEL_NAME = "openai/whisper-large-v3"
|
11 |
BATCH_SIZE = 8
|
12 |
FILE_LIMIT_MB = 1000
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
15 |
|
16 |
device = 0 if torch.cuda.is_available() else "cpu"
|
17 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
18 |
|
19 |
-
|
20 |
task="automatic-speech-recognition",
|
21 |
model=MODEL_NAME,
|
22 |
chunk_length_s=30,
|
23 |
device=device,
|
24 |
torch_dtype=dtype,
|
25 |
-
return_timestamps=True,
|
26 |
)
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
ms = int(float(seconds) * 1000 + 0.5)
|
33 |
h, ms = divmod(ms, 3600000)
|
34 |
m, ms = divmod(ms, 60000)
|
35 |
s, ms = divmod(ms, 1000)
|
36 |
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
units, cur = [], []
|
42 |
-
for ch in txt:
|
43 |
-
cur.append(ch)
|
44 |
-
if ch in strong:
|
45 |
-
units.append("".join(cur).strip())
|
46 |
-
cur = []
|
47 |
-
if cur:
|
48 |
-
units.append("".join(cur).strip())
|
49 |
-
|
50 |
-
refined = []
|
51 |
-
for u in units:
|
52 |
-
if len(u) <= max_seg_chars:
|
53 |
-
refined.append(u)
|
54 |
-
else:
|
55 |
-
# 长句继续切分,并自动补句号
|
56 |
-
for i in range(0, len(u), max_seg_chars):
|
57 |
-
piece = u[i:i+max_seg_chars].strip()
|
58 |
-
if piece and piece[-1] not in strong:
|
59 |
-
piece += "。"
|
60 |
-
refined.append(piece)
|
61 |
-
# 如果最后一段没有标点,补句号
|
62 |
-
if refined and refined[-1][-1] not in strong:
|
63 |
-
refined[-1] += "。"
|
64 |
-
return [x for x in refined if x]
|
65 |
-
|
66 |
-
# ===== chunks 转 SRT (无编号 + 自动标点) =====
|
67 |
-
def chunks_to_srt(chunks: list[dict],
|
68 |
-
max_seg_dur: float = MAX_SEG_DUR,
|
69 |
-
max_seg_chars: int = MAX_SEG_CHARS) -> str:
|
70 |
-
lines = []
|
71 |
for ch in chunks or []:
|
72 |
text = (ch.get("text") or "").strip()
|
|
|
73 |
if not text:
|
74 |
continue
|
75 |
-
|
76 |
-
ts = ch.get("timestamp") or ch.get("timestamps") or [0.0, 2.0]
|
77 |
if isinstance(ts, (list, tuple)) and len(ts) == 2:
|
78 |
-
|
79 |
else:
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
else:
|
97 |
-
#
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
st = cur_t
|
106 |
-
en = st + per
|
107 |
-
lines.append(f"{_srt_timestamp(st)} --> {_srt_timestamp(en)}")
|
108 |
-
lines.append(p.strip())
|
109 |
-
lines.append("")
|
110 |
-
cur_t = en
|
111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
return "\n".join(lines).strip() + ("\n" if lines else "")
|
113 |
|
114 |
-
#
|
115 |
@spaces.GPU
|
116 |
def transcribe_file_to_srt(audio_path: str, task: str):
|
117 |
if not audio_path:
|
@@ -123,13 +235,13 @@ def transcribe_file_to_srt(audio_path: str, task: str):
|
|
123 |
except OSError:
|
124 |
pass
|
125 |
|
126 |
-
result =
|
127 |
-
text = result.get("text"
|
128 |
chunks = result.get("chunks") or []
|
129 |
|
130 |
-
srt_str =
|
131 |
-
if not srt_str and text
|
132 |
-
srt_str = "00:00:00,000 --> 00:00:02,000\n" + (text
|
133 |
|
134 |
tmpdir = tempfile.mkdtemp(prefix="srt_")
|
135 |
base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
|
@@ -139,7 +251,6 @@ def transcribe_file_to_srt(audio_path: str, task: str):
|
|
139 |
|
140 |
return srt_str, srt_path
|
141 |
|
142 |
-
# ===== 界面 =====
|
143 |
demo = gr.Interface(
|
144 |
fn=transcribe_file_to_srt,
|
145 |
inputs=[
|
@@ -147,12 +258,12 @@ demo = gr.Interface(
|
|
147 |
gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
|
148 |
],
|
149 |
outputs=[
|
150 |
-
gr.Textbox(label="Transcript (SRT Preview)", lines=18),
|
151 |
gr.File(label="Download SRT"),
|
152 |
],
|
153 |
-
title="Upload Audio → SRT
|
154 |
-
description=f"
|
155 |
allow_flagging="never",
|
156 |
)
|
157 |
|
158 |
-
demo.queue().launch()
|
|
|
6 |
import tempfile
|
7 |
import os
|
8 |
|
9 |
+
# ================== 可调参数 ==================
|
10 |
MODEL_NAME = "openai/whisper-large-v3"
|
11 |
BATCH_SIZE = 8
|
12 |
FILE_LIMIT_MB = 1000
|
13 |
+
|
14 |
+
TARGET_SENT_CHARS = 12 # 目标每句字数(中文场景)
|
15 |
+
COMMA_EVERY = 0 # 如需更细粒度短停顿,可设 6/8(表示每 N 字加一个“,”并收句);0 表示关闭
|
16 |
+
MAX_SEG_DUR = 6.0 # 每句最长时长(秒)
|
17 |
+
MIN_PIECE_DUR = 0.30 # 每句最小时长(秒),避免闪烁
|
18 |
+
STRONG_PUNCT = "。!?.!?"
|
19 |
|
20 |
device = 0 if torch.cuda.is_available() else "cpu"
|
21 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
22 |
|
23 |
+
asr = pipeline(
|
24 |
task="automatic-speech-recognition",
|
25 |
model=MODEL_NAME,
|
26 |
chunk_length_s=30,
|
27 |
device=device,
|
28 |
torch_dtype=dtype,
|
29 |
+
return_timestamps=True, # 仅需 chunk 级 (start,end)
|
30 |
)
|
31 |
|
32 |
+
def _ts(t: float | None) -> str:
|
33 |
+
if t is None or t < 0:
|
34 |
+
t = 0.0
|
35 |
+
ms = int(float(t) * 1000 + 0.5)
|
|
|
36 |
h, ms = divmod(ms, 3600000)
|
37 |
m, ms = divmod(ms, 60000)
|
38 |
s, ms = divmod(ms, 1000)
|
39 |
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
|
40 |
|
41 |
+
def _norm_chunks(chunks: list[dict]) -> list[dict]:
|
42 |
+
"""规范化 chunks: [{'text': str, 'start': float, 'end': float}]"""
|
43 |
+
out = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
for ch in chunks or []:
|
45 |
text = (ch.get("text") or "").strip()
|
46 |
+
ts = ch.get("timestamp") or ch.get("timestamps") or [0.0, 2.0]
|
47 |
if not text:
|
48 |
continue
|
|
|
|
|
49 |
if isinstance(ts, (list, tuple)) and len(ts) == 2:
|
50 |
+
s, e = float(ts[0] or 0.0), float(ts[1] or 0.0)
|
51 |
else:
|
52 |
+
s, e = 0.0, 2.0
|
53 |
+
if e < s:
|
54 |
+
e = s
|
55 |
+
out.append({"text": text, "start": s, "end": e})
|
56 |
+
return out
|
57 |
+
|
58 |
+
def _char_timeline(chunk: dict) -> list[tuple[str, float, float]]:
|
59 |
+
"""
|
60 |
+
把一个 chunk 的文本按字符建立时间轴:
|
61 |
+
返回 [(char, char_start, char_end), ...]
|
62 |
+
"""
|
63 |
+
text = chunk["text"]
|
64 |
+
s, e = chunk["start"], chunk["end"]
|
65 |
+
dur = max(e - s, 0.0)
|
66 |
+
n = max(len(text), 1)
|
67 |
+
step = dur / n if n > 0 else 0.0
|
68 |
+
timeline = []
|
69 |
+
cur = s
|
70 |
+
for i, ch in enumerate(text):
|
71 |
+
nxt = s + (i + 1) * step
|
72 |
+
timeline.append((ch, cur, nxt))
|
73 |
+
cur = nxt
|
74 |
+
return timeline
|
75 |
+
|
76 |
+
def _segment_short_sentences(char_stream: list[tuple[str, float, float]],
|
77 |
+
target_chars: int = TARGET_SENT_CHARS,
|
78 |
+
comma_every: int = COMMA_EVERY,
|
79 |
+
max_seg_dur: float = MAX_SEG_DUR) -> list[tuple[float, float, str]]:
|
80 |
+
"""
|
81 |
+
核心切分:
|
82 |
+
- 累积字符直到遇到强标点 或 达到 target_chars
|
83 |
+
- 可选:每 comma_every 个字符插入逗号并收句
|
84 |
+
- 强标点永远并入本句,绝不产生“单独标点句”
|
85 |
+
- 超长句再按时长 <= max_seg_dur 均匀切
|
86 |
+
"""
|
87 |
+
segments = []
|
88 |
+
|
89 |
+
buf_chars: list[str] = []
|
90 |
+
buf_start = None
|
91 |
+
last_char_end = None
|
92 |
+
since_last_comma = 0
|
93 |
+
|
94 |
+
def flush_sentence(force_punct=False):
|
95 |
+
nonlocal buf_chars, buf_start, last_char_end, since_last_comma
|
96 |
+
if not buf_chars:
|
97 |
+
return
|
98 |
+
text = "".join(buf_chars).strip()
|
99 |
+
if not text:
|
100 |
+
buf_chars = []
|
101 |
+
buf_start = None
|
102 |
+
since_last_comma = 0
|
103 |
+
return
|
104 |
+
# 保证句末有强标点
|
105 |
+
if force_punct and text[-1] not in STRONG_PUNCT:
|
106 |
+
text += "。"
|
107 |
+
elif text[-1] not in STRONG_PUNCT:
|
108 |
+
text += "。"
|
109 |
+
st = buf_start if buf_start is not None else 0.0
|
110 |
+
en = last_char_end if last_char_end is not None else st
|
111 |
+
# 时长保护
|
112 |
+
if en - st < MIN_PIECE_DUR:
|
113 |
+
en = st + MIN_PIECE_DUR
|
114 |
+
segments.append((st, en, text))
|
115 |
+
buf_chars = []
|
116 |
+
buf_start = None
|
117 |
+
since_last_comma = 0
|
118 |
+
|
119 |
+
def try_hard_wrap_long(st: float, en: float, text: str):
|
120 |
+
"""
|
121 |
+
如果单句太长(> max_seg_dur),按时长把文本均匀切成多块,每块 <= max_seg_dur,句末补句号。
|
122 |
+
使用 [st,en] 线性映射。
|
123 |
+
"""
|
124 |
+
out = []
|
125 |
+
dur = max(en - st, 0.0)
|
126 |
+
if dur <= max_seg_dur:
|
127 |
+
return [(st, en, text if text[-1] in STRONG_PUNCT else (text + "。"))]
|
128 |
+
|
129 |
+
# 需要切成 k 块
|
130 |
+
k = int(dur // max_seg_dur) + 1
|
131 |
+
# 按字符再均匀切
|
132 |
+
L = len(text)
|
133 |
+
piece_len = max(L // k, 1)
|
134 |
+
pos = 0
|
135 |
+
for i in range(k):
|
136 |
+
sub = text[pos:pos + piece_len]
|
137 |
+
if not sub:
|
138 |
+
continue
|
139 |
+
sub_st = st + (i / k) * dur
|
140 |
+
sub_en = st + ((i + 1) / k) * dur
|
141 |
+
if sub[-1] not in STRONG_PUNCT:
|
142 |
+
sub += "。"
|
143 |
+
if sub_en - sub_st < MIN_PIECE_DUR:
|
144 |
+
sub_en = sub_st + MIN_PIECE_DUR
|
145 |
+
out.append((sub_st, sub_en, sub))
|
146 |
+
pos += piece_len
|
147 |
+
# 余数
|
148 |
+
if pos < L:
|
149 |
+
sub = text[pos:]
|
150 |
+
sub_st = st + (len("".join([t for _,_,t in out])) / max(L,1)) * dur
|
151 |
+
sub_en = en
|
152 |
+
if sub and sub[-1] not in STRONG_PUNCT:
|
153 |
+
sub += "。"
|
154 |
+
if sub_en - sub_st < MIN_PIECE_DUR:
|
155 |
+
sub_en = sub_st + MIN_PIECE_DUR
|
156 |
+
out.append((sub_st, sub_en, sub))
|
157 |
+
return out
|
158 |
+
|
159 |
+
# 遍历逐字时间线
|
160 |
+
for ch, ch_st, ch_en in char_stream:
|
161 |
+
if buf_start is None:
|
162 |
+
buf_start = ch_st
|
163 |
+
buf_chars.append(ch)
|
164 |
+
last_char_end = ch_en
|
165 |
+
since_last_comma += 1
|
166 |
+
|
167 |
+
# 强标点:直接并入当前句并收句
|
168 |
+
if ch in STRONG_PUNCT:
|
169 |
+
flush_sentence(force_punct=False)
|
170 |
+
continue
|
171 |
+
|
172 |
+
# 逗号式短停顿(可选)
|
173 |
+
if comma_every and since_last_comma >= comma_every:
|
174 |
+
# 只在当前累积达到目标一半以上时才加逗号,避免太碎
|
175 |
+
if len(buf_chars) >= max(6, target_chars // 2):
|
176 |
+
if buf_chars and buf_chars[-1] not in ",,、;;" and buf_chars[-1] not in STRONG_PUNCT:
|
177 |
+
buf_chars.append(",")
|
178 |
+
flush_sentence(force_punct=False)
|
179 |
+
continue
|
180 |
else:
|
181 |
+
since_last_comma = 0 # 重置计数,继续攒
|
182 |
+
|
183 |
+
# 达到目标长度:收句并补句号
|
184 |
+
if len(buf_chars) >= target_chars:
|
185 |
+
flush_sentence(force_punct=True)
|
186 |
+
|
187 |
+
# 收尾
|
188 |
+
flush_sentence(force_punct=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
+
# 二次处理:把任何超时长的句子再按时长切块(<= MAX_SEG_DUR)
|
191 |
+
final_segments = []
|
192 |
+
for st, en, tx in segments:
|
193 |
+
if en - st > max_seg_dur:
|
194 |
+
final_segments.extend(try_hard_wrap_long(st, en, tx))
|
195 |
+
else:
|
196 |
+
final_segments.append((st, en, tx if tx[-1] in STRONG_PUNCT else (tx + "。")))
|
197 |
+
|
198 |
+
return final_segments
|
199 |
+
|
200 |
+
def chunks_to_srt_no_number(chunks: list[dict]) -> str:
|
201 |
+
"""
|
202 |
+
外层封装:逐 chunk 建立字符时间线 → 合并 → 切分 → 输出无编号 SRT。
|
203 |
+
"""
|
204 |
+
norm = _norm_chunks(chunks)
|
205 |
+
# 构建全局逐字时间线(按 chunk 顺序拼接)
|
206 |
+
char_stream = []
|
207 |
+
for ch in norm:
|
208 |
+
char_stream.extend(_char_timeline(ch))
|
209 |
+
|
210 |
+
# 切分为短句片段
|
211 |
+
segs = _segment_short_sentences(
|
212 |
+
char_stream,
|
213 |
+
target_chars=TARGET_SENT_CHARS,
|
214 |
+
comma_every=COMMA_EVERY,
|
215 |
+
max_seg_dur=MAX_SEG_DUR,
|
216 |
+
)
|
217 |
+
|
218 |
+
# 输出(无编号)
|
219 |
+
lines = []
|
220 |
+
for st, en, tx in segs:
|
221 |
+
lines.append(f"{_ts(st)} --> {_ts(en)}")
|
222 |
+
lines.append(tx.strip())
|
223 |
+
lines.append("")
|
224 |
return "\n".join(lines).strip() + ("\n" if lines else "")
|
225 |
|
226 |
+
# ================== 推理与UI ==================
|
227 |
@spaces.GPU
|
228 |
def transcribe_file_to_srt(audio_path: str, task: str):
|
229 |
if not audio_path:
|
|
|
235 |
except OSError:
|
236 |
pass
|
237 |
|
238 |
+
result = asr(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task})
|
239 |
+
text = (result.get("text") or "").strip()
|
240 |
chunks = result.get("chunks") or []
|
241 |
|
242 |
+
srt_str = chunks_to_srt_no_number(chunks)
|
243 |
+
if not srt_str and text:
|
244 |
+
srt_str = "00:00:00,000 --> 00:00:02,000\n" + (text + ("。" if text[-1] not in STRONG_PUNCT else "")) + "\n"
|
245 |
|
246 |
tmpdir = tempfile.mkdtemp(prefix="srt_")
|
247 |
base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
|
|
|
251 |
|
252 |
return srt_str, srt_path
|
253 |
|
|
|
254 |
demo = gr.Interface(
|
255 |
fn=transcribe_file_to_srt,
|
256 |
inputs=[
|
|
|
258 |
gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
|
259 |
],
|
260 |
outputs=[
|
261 |
+
gr.Textbox(label="Transcript (SRT Preview — short sentences, no numbering)", lines=18),
|
262 |
gr.File(label="Download SRT"),
|
263 |
],
|
264 |
+
title="Upload Audio → SRT (Short Sentences, No Numbering)",
|
265 |
+
description=f"Character-timeline resegmentation. Natural short sentences like “他跟三国志他不一样。/ 他也是在那个基础上。/ …”. Model: {MODEL_NAME}",
|
266 |
allow_flagging="never",
|
267 |
)
|
268 |
|
269 |
+
demo.queue().launch()
|