Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,16 +1,17 @@
|
|
|
|
1 |
import spaces
|
2 |
import torch
|
3 |
import gradio as gr
|
4 |
from transformers import pipeline
|
5 |
-
|
6 |
import tempfile
|
7 |
import os
|
8 |
-
from datetime import timedelta
|
9 |
|
10 |
-
# =====
|
11 |
MODEL_NAME = "openai/whisper-large-v3"
|
12 |
BATCH_SIZE = 8
|
13 |
-
FILE_LIMIT_MB = 1000
|
|
|
|
|
14 |
|
15 |
device = 0 if torch.cuda.is_available() else "cpu"
|
16 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
@@ -21,12 +22,11 @@ pipe = pipeline(
|
|
21 |
chunk_length_s=30,
|
22 |
device=device,
|
23 |
torch_dtype=dtype,
|
24 |
-
return_timestamps=
|
25 |
)
|
26 |
|
27 |
-
# =====
|
28 |
-
def _srt_timestamp(seconds):
|
29 |
-
"""秒 -> SRT 时间戳 00:00:00,000。None/负数时归零。"""
|
30 |
if seconds is None or seconds < 0:
|
31 |
seconds = 0.0
|
32 |
ms = int(float(seconds) * 1000 + 0.5)
|
@@ -35,92 +35,83 @@ def _srt_timestamp(seconds):
|
|
35 |
s, ms = divmod(ms, 1000)
|
36 |
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
if
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
if isinstance(st, (list, tuple)): st = st[0]
|
72 |
-
if isinstance(en, (list, tuple)): en = en[-1]
|
73 |
-
dur = float((en or 0.0) - (st or 0.0))
|
74 |
-
if force or strong_punct or dur >= max_seg_dur or cur_len >= max_seg_chars:
|
75 |
-
flush_seg()
|
76 |
-
|
77 |
-
# 汇总所有词
|
78 |
-
all_words = []
|
79 |
for ch in chunks or []:
|
80 |
-
|
81 |
-
if not
|
82 |
-
ts = ch.get("timestamp") or ch.get("timestamps") or [0.0, 2.0]
|
83 |
-
if isinstance(ts, (list, tuple)) and len(ts) == 2:
|
84 |
-
all_words.append({"word": ch["text"], "start": ts[0], "end": ts[1]})
|
85 |
-
else:
|
86 |
-
all_words.append({"word": ch["text"], "start": 0.0, "end": 2.0})
|
87 |
continue
|
88 |
-
for w in words:
|
89 |
-
token = (w.get("word") or "").replace("\n", " ")
|
90 |
-
start = w.get("start")
|
91 |
-
end = w.get("end")
|
92 |
-
if (start is None or end is None) and isinstance(w.get("timestamp"), (list, tuple)) and len(w["timestamp"]) == 2:
|
93 |
-
start, end = w["timestamp"]
|
94 |
-
all_words.append({"word": token, "start": start, "end": end})
|
95 |
-
|
96 |
-
# 若依旧拿不到逐词,回退整段文本
|
97 |
-
if not all_words and text_fallback.strip():
|
98 |
-
all_words = [{"word": text_fallback.strip(), "start": 0.0, "end": max_seg_dur}]
|
99 |
-
|
100 |
-
# 按规则切分
|
101 |
-
for w in all_words:
|
102 |
-
token = w.get("word", "")
|
103 |
-
if not token:
|
104 |
-
continue
|
105 |
-
if cur_start is None:
|
106 |
-
cur_start = w.get("start", 0.0)
|
107 |
-
cur_words.append(w)
|
108 |
-
cur_len += len(token)
|
109 |
-
strong = token.endswith(("。", "!", "?", ".", "!", "?"))
|
110 |
-
maybe_flush(force=False, strong_punct=strong)
|
111 |
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
-
# 生成 SRT
|
115 |
-
lines = []
|
116 |
-
for i, (st, en, txt) in enumerate(segs, 1):
|
117 |
-
lines.append(str(i))
|
118 |
-
lines.append(f"{_srt_timestamp(st)} --> {_srt_timestamp(en)}")
|
119 |
-
lines.append(txt)
|
120 |
-
lines.append("")
|
121 |
return "\n".join(lines).strip() + ("\n" if lines else "")
|
122 |
|
123 |
-
# ===== 上传音频
|
124 |
@spaces.GPU
|
125 |
def transcribe_file_to_srt(audio_path: str, task: str):
|
126 |
if not audio_path:
|
@@ -136,20 +127,19 @@ def transcribe_file_to_srt(audio_path: str, task: str):
|
|
136 |
text = result.get("text", "") or ""
|
137 |
chunks = result.get("chunks") or []
|
138 |
|
139 |
-
|
140 |
-
srt_str
|
|
|
141 |
|
142 |
-
# 写入临时文件供下载
|
143 |
tmpdir = tempfile.mkdtemp(prefix="srt_")
|
144 |
base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
|
145 |
srt_path = os.path.join(tmpdir, f"{base}.srt")
|
146 |
with open(srt_path, "w", encoding="utf-8") as f:
|
147 |
f.write(srt_str)
|
148 |
|
149 |
-
# 第一个输出显示 SRT 字符串,第二个输出提供下载
|
150 |
return srt_str, srt_path
|
151 |
|
152 |
-
# =====
|
153 |
demo = gr.Interface(
|
154 |
fn=transcribe_file_to_srt,
|
155 |
inputs=[
|
@@ -161,10 +151,7 @@ demo = gr.Interface(
|
|
161 |
gr.File(label="Download SRT"),
|
162 |
],
|
163 |
title="Upload Audio → SRT Subtitle",
|
164 |
-
description=(
|
165 |
-
"Upload an audio file to generate time-stamped SRT subtitles. "
|
166 |
-
f"Backed by [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME})."
|
167 |
-
),
|
168 |
allow_flagging="never",
|
169 |
)
|
170 |
|
|
|
1 |
+
# app.py
|
2 |
import spaces
|
3 |
import torch
|
4 |
import gradio as gr
|
5 |
from transformers import pipeline
|
|
|
6 |
import tempfile
|
7 |
import os
|
|
|
8 |
|
9 |
+
# ===== 参数 =====
|
10 |
MODEL_NAME = "openai/whisper-large-v3"
|
11 |
BATCH_SIZE = 8
|
12 |
+
FILE_LIMIT_MB = 1000
|
13 |
+
MAX_SEG_DUR = 6.0
|
14 |
+
MAX_SEG_CHARS = 28
|
15 |
|
16 |
device = 0 if torch.cuda.is_available() else "cpu"
|
17 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
|
|
22 |
chunk_length_s=30,
|
23 |
device=device,
|
24 |
torch_dtype=dtype,
|
25 |
+
return_timestamps=True,
|
26 |
)
|
27 |
|
28 |
+
# ===== 时间戳格式化 =====
|
29 |
+
def _srt_timestamp(seconds: float | None) -> str:
|
|
|
30 |
if seconds is None or seconds < 0:
|
31 |
seconds = 0.0
|
32 |
ms = int(float(seconds) * 1000 + 0.5)
|
|
|
35 |
s, ms = divmod(ms, 1000)
|
36 |
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
|
37 |
|
38 |
+
# ===== 文本切分 + 自动补标点 =====
|
39 |
+
def _split_text_units(txt: str, max_seg_chars: int) -> list[str]:
|
40 |
+
strong = "。!?.!?"
|
41 |
+
units, cur = [], []
|
42 |
+
for ch in txt:
|
43 |
+
cur.append(ch)
|
44 |
+
if ch in strong:
|
45 |
+
units.append("".join(cur).strip())
|
46 |
+
cur = []
|
47 |
+
if cur:
|
48 |
+
units.append("".join(cur).strip())
|
49 |
+
|
50 |
+
refined = []
|
51 |
+
for u in units:
|
52 |
+
if len(u) <= max_seg_chars:
|
53 |
+
refined.append(u)
|
54 |
+
else:
|
55 |
+
# 长句继续切分,并自动补句号
|
56 |
+
for i in range(0, len(u), max_seg_chars):
|
57 |
+
piece = u[i:i+max_seg_chars].strip()
|
58 |
+
if piece and piece[-1] not in strong:
|
59 |
+
piece += "。"
|
60 |
+
refined.append(piece)
|
61 |
+
# 如果最后一段没有标点,补句号
|
62 |
+
if refined and refined[-1][-1] not in strong:
|
63 |
+
refined[-1] += "。"
|
64 |
+
return [x for x in refined if x]
|
65 |
+
|
66 |
+
# ===== chunks 转 SRT (无编号 + 自动标点) =====
|
67 |
+
def chunks_to_srt(chunks: list[dict],
                  max_seg_dur: float = MAX_SEG_DUR,
                  max_seg_chars: int = MAX_SEG_CHARS) -> str:
    """Render Whisper pipeline chunks as SRT cue text (no cue numbers).

    Each chunk's time span is shared among its sentence units in
    proportion to character count (with a 0.2 s floor per unit); an
    allocation exceeding *max_seg_dur* is re-split into fixed-width
    pieces that divide the time evenly.  Returns the joined cue text
    terminated by a newline, or "" when there is nothing to emit.
    """
    out: list[str] = []
    for chunk in chunks or []:
        body = (chunk.get("text") or "").strip()
        if not body:
            continue

        # Chunk-level timestamps; fall back to a 2-second window.
        span = chunk.get("timestamp") or chunk.get("timestamps") or [0.0, 2.0]
        if isinstance(span, (list, tuple)) and len(span) == 2:
            t0, t1 = float(span[0] or 0.0), float(span[1] or 0.0)
        else:
            t0, t1 = 0.0, 2.0

        units = _split_text_units(body, max_seg_chars) or [body]

        char_total = sum(len(u) for u in units) or 1
        span_dur = max(t1 - t0, 0.0)

        clock = t0
        for unit in units:
            # Proportional share of the chunk duration, floored at 0.2 s.
            share = max(span_dur * (len(unit) / char_total), 0.2)
            if share <= max_seg_dur:
                parts = [unit]
                each = share
            else:
                # Too long for a single cue: re-slice, split time evenly.
                raw = [unit[i:i+max_seg_chars] for i in range(0, len(unit), max_seg_chars)]
                parts = [s.strip() + ("。" if not s.endswith("。") else "") for s in raw if s.strip()]
                each = min(max_seg_dur, share / max(1, len(parts)))

            for part in parts:
                if part and part[-1] not in "。!?.!?":
                    part += "。"
                begin = clock
                finish = begin + each
                out.append(f"{_srt_timestamp(begin)} --> {_srt_timestamp(finish)}")
                out.append(part.strip())
                out.append("")
                clock = finish
    return "\n".join(out).strip() + ("\n" if out else "")
|
113 |
|
114 |
+
# ===== 上传音频 → SRT =====
|
115 |
@spaces.GPU
|
116 |
def transcribe_file_to_srt(audio_path: str, task: str):
|
117 |
if not audio_path:
|
|
|
127 |
text = result.get("text", "") or ""
|
128 |
chunks = result.get("chunks") or []
|
129 |
|
130 |
+
srt_str = chunks_to_srt(chunks)
|
131 |
+
if not srt_str and text.strip():
|
132 |
+
srt_str = "00:00:00,000 --> 00:00:02,000\n" + (text.strip() + "。") + "\n"
|
133 |
|
|
|
134 |
tmpdir = tempfile.mkdtemp(prefix="srt_")
|
135 |
base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
|
136 |
srt_path = os.path.join(tmpdir, f"{base}.srt")
|
137 |
with open(srt_path, "w", encoding="utf-8") as f:
|
138 |
f.write(srt_str)
|
139 |
|
|
|
140 |
return srt_str, srt_path
|
141 |
|
142 |
+
# ===== 界面 =====
|
143 |
demo = gr.Interface(
|
144 |
fn=transcribe_file_to_srt,
|
145 |
inputs=[
|
|
|
151 |
gr.File(label="Download SRT"),
|
152 |
],
|
153 |
title="Upload Audio → SRT Subtitle",
|
154 |
+
description=f"Upload an audio file to generate time-stamped SRT subtitles (auto punctuation, no numbering). Model: {MODEL_NAME}",
|
|
|
|
|
|
|
155 |
allow_flagging="never",
|
156 |
)
|
157 |
|