datxy committed on
Commit
261c97f
·
verified ·
1 Parent(s): cdddc7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -149
app.py CHANGED
@@ -2,32 +2,30 @@
2
  import spaces
3
  import torch
4
  import gradio as gr
 
5
  import tempfile
6
  import os
7
 
8
- # ====== 可选:faster-whisper,若存在则自动使用(vad_filter=True)======
9
- USE_FASTER_WHISPER = True
10
- try:
11
- if USE_FASTER_WHISPER:
12
- from faster_whisper import WhisperModel # type: ignore
13
- _HAS_FW = True
14
- except Exception:
15
- _HAS_FW = False
16
-
17
- # ====== 可调参数 ======
18
- ASR_MODEL = "openai/whisper-large-v3" # Transformers 管线备用模型
19
  BATCH_SIZE = 8
20
  FILE_LIMIT_MB = 1000
21
- MAX_SEG_DUR = 6.0 # 每行最大时长(秒)
22
- MAX_SEG_CHARS = 28 # 每行最大字符数(中文可适当减小)
23
- PAUSE_LONG = 0.9 # 认为是“句末停顿”的阈值(秒)→ 补句号
24
- PAUSE_SHORT = 0.45 # 认为是“轻停顿”的阈值(秒)→ 补逗号
25
- MIN_PIECE_DUR = 0.2 # 每小片的最小时长,避免 0
26
-
27
- device = "cuda" if torch.cuda.is_available() else "cpu"
28
- dtype = "float16" if torch.cuda.is_available() else "float32"
 
 
 
 
 
 
29
 
30
- # ====== 时间戳工具 ======
31
  def _srt_timestamp(seconds: float | None) -> str:
32
  if seconds is None or seconds < 0:
33
  seconds = 0.0
@@ -37,25 +35,9 @@ def _srt_timestamp(seconds: float | None) -> str:
37
  s, ms = divmod(ms, 1000)
38
  return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
39
 
40
- # ====== 简易标点:根据停顿/片段长度补标点 ======
41
- _STRONG = "。!?.!?"
42
- def _ensure_sentence_punct(s: str) -> str:
43
- s = s.strip()
44
- if not s:
45
- return s
46
- if s[-1] not in _STRONG:
47
- s += "。"
48
- return s
49
-
50
- def _maybe_add_comma(s: str) -> str:
51
- s = s.rstrip()
52
- if s and s[-1] not in _STRONG + ",,、;;":
53
- s += ","
54
- return s
55
-
56
- # ====== 文本切分(标点优先,其次按长度兜底)=====
57
  def _split_text_units(txt: str, max_seg_chars: int) -> list[str]:
58
- strong = _STRONG
59
  units, cur = [], []
60
  for ch in txt:
61
  cur.append(ch)
@@ -70,16 +52,22 @@ def _split_text_units(txt: str, max_seg_chars: int) -> list[str]:
70
  if len(u) <= max_seg_chars:
71
  refined.append(u)
72
  else:
73
- # 粗切为不超过 max_seg_chars 的块
74
  for i in range(0, len(u), max_seg_chars):
75
- refined.append(u[i:i + max_seg_chars].strip())
 
 
 
 
 
 
76
  return [x for x in refined if x]
77
 
78
- # ====== 基于 chunk 的字符比例细分 + 简易标点 ======
79
- def _chunks_to_srt_no_number(chunks, max_seg_dur=MAX_SEG_DUR, max_seg_chars=MAX_SEG_CHARS) -> str:
 
 
80
  lines = []
81
- prev_end = None
82
-
83
  for ch in chunks or []:
84
  text = (ch.get("text") or "").strip()
85
  if not text:
@@ -91,13 +79,6 @@ def _chunks_to_srt_no_number(chunks, max_seg_dur=MAX_SEG_DUR, max_seg_chars=MAX_
91
  else:
92
  c_start, c_end = 0.0, 2.0
93
 
94
- # 根据与上一段的“停顿”补逗号/句号(软提示,真正的标点在行尾处理)
95
- if prev_end is not None:
96
- gap = max(c_start - prev_end, 0.0)
97
- # 我们不直接把标点写进上一行文本,而是作为划分参考
98
- # 具体标点在每行最终输出前处理(见下方)
99
- prev_end = c_end
100
-
101
  units = _split_text_units(text, max_seg_chars)
102
  if not units:
103
  units = [text]
@@ -107,100 +88,32 @@ def _chunks_to_srt_no_number(chunks, max_seg_dur=MAX_SEG_DUR, max_seg_chars=MAX_
107
 
108
  cur_t = c_start
109
  for u in units:
110
- alloc = max(total_dur * (len(u) / total_chars), MIN_PIECE_DUR)
111
-
112
- # 若超长,继续细分为不超过 max_seg_chars 的片,并均分时长
113
  if alloc <= max_seg_dur:
114
  pieces = [u]
115
  per = alloc
116
  else:
117
- smalls = [u[i:i + max_seg_chars] for i in range(0, len(u), max_seg_chars)]
118
- pieces = [s for s in smalls if s.strip()]
119
- per = max(min(max_seg_dur, alloc / max(1, len(pieces))), MIN_PIECE_DUR)
120
-
121
- for i, p in enumerate(pieces):
 
 
 
122
  st = cur_t
123
  en = st + per
124
-
125
- frag = p.strip()
126
-
127
- # 行尾自动补标点:更自然
128
- # 规则:
129
- # 1) 若该小片接近一句末尾(片段较长 或 达到 max_seg_chars)→ 句号
130
- # 2) 否则可能是轻停顿 → 逗号
131
- # 3) 若已有强标点,保持不动
132
- if frag and frag[-1] not in _STRONG:
133
- # 按时长和长度推断语气
134
- if per >= PAUSE_LONG or len(frag) >= max_seg_chars * 0.9:
135
- frag = _ensure_sentence_punct(frag)
136
- elif per >= PAUSE_SHORT:
137
- frag = _maybe_add_comma(frag)
138
- # else: 很短的片段不强制加标点(避免过密)
139
-
140
  lines.append(f"{_srt_timestamp(st)} --> {_srt_timestamp(en)}")
141
- lines.append(frag)
142
  lines.append("")
143
  cur_t = en
144
 
145
  return "\n".join(lines).strip() + ("\n" if lines else "")
146
 
147
- # ====== 如果有 faster-whisper:使用 VAD 直接拿到更干净的分段 ======
148
- def _fw_transcribe_to_chunks(audio_path: str):
149
- """
150
- 使用 faster-whisper + vad_filter=True 返回“类 chunk”结构:
151
- [{'text': '...', 'timestamp': [start, end]}...]
152
- """
153
- # 模型建议设为 medium/large-v3;这里默认 large-v3
154
- model = WhisperModel("large-v3", device=device, compute_type="auto")
155
- # vad_filter=True:用 Silero VAD 过滤静音/噪音
156
- segments, _info = model.transcribe(
157
- audio_path,
158
- vad_filter=True,
159
- vad_parameters=dict(min_silence_duration_ms=int(PAUSE_SHORT * 1000)),
160
- beam_size=5,
161
- best_of=5,
162
- )
163
- chunks = []
164
- for seg in segments:
165
- chunks.append({
166
- "text": seg.text.strip(),
167
- "timestamp": [float(seg.start or 0.0), float(seg.end or 0.0)],
168
- })
169
- return chunks
170
-
171
- # ====== Transformers 回退方案 ======
172
- from transformers import pipeline as hf_pipeline
173
- def _hf_transcribe_to_chunks(audio_path: str):
174
- pipe = hf_pipeline(
175
- task="automatic-speech-recognition",
176
- model=ASR_MODEL,
177
- chunk_length_s=30,
178
- device=0 if torch.cuda.is_available() else "cpu",
179
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
180
- return_timestamps=True,
181
- )
182
- result = pipe(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": "transcribe"})
183
- # 期望返回形如:{"text": "…", "chunks": [{"text": "...", "timestamp": (s,e)}, ...]}
184
- chunks = result.get("chunks") or []
185
- # 如无 chunks,用整段兜底
186
- if not chunks:
187
- total_text = (result.get("text") or "").strip()
188
- if total_text:
189
- chunks = [{"text": total_text, "timestamp": (0.0, max(MAX_SEG_DUR, 2.0))}]
190
- # 统一结构
191
- norm = []
192
- for ch in chunks:
193
- ts = ch.get("timestamp") or ch.get("timestamps") or [0.0, 2.0]
194
- if isinstance(ts, (list, tuple)) and len(ts) == 2:
195
- s, e = ts
196
- else:
197
- s, e = 0.0, 2.0
198
- norm.append({"text": (ch.get("text") or "").strip(), "timestamp": [float(s or 0.0), float(e or 0.0)]})
199
- return norm
200
-
201
- # ====== 主函数:上传音频 → SRT(无编号 + 简易标点 + 可选 VAD)======
202
  @spaces.GPU
203
- def transcribe_file_to_srt(audio_path: str, task: str, use_vad: bool):
204
  if not audio_path:
205
  raise gr.Error("请先上传音频文件。")
206
  try:
@@ -210,17 +123,14 @@ def transcribe_file_to_srt(audio_path: str, task: str, use_vad: bool):
210
  except OSError:
211
  pass
212
 
213
- # 选择后端:优先 faster-whisper + VAD,其次 Transformers
214
- if use_vad and _HAS_FW:
215
- chunks = _fw_transcribe_to_chunks(audio_path)
216
- else:
217
- chunks = _hf_transcribe_to_chunks(audio_path)
218
 
219
- srt_str = _chunks_to_srt_no_number(chunks, MAX_SEG_DUR, MAX_SEG_CHARS)
220
- if not srt_str:
221
- srt_str = "00:00:00,000 --> 00:00:02,000\n(空)\n"
222
 
223
- # 输出文件
224
  tmpdir = tempfile.mkdtemp(prefix="srt_")
225
  base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
226
  srt_path = os.path.join(tmpdir, f"{base}.srt")
@@ -229,23 +139,19 @@ def transcribe_file_to_srt(audio_path: str, task: str, use_vad: bool):
229
 
230
  return srt_str, srt_path
231
 
232
- # ====== UI ======
233
  demo = gr.Interface(
234
  fn=transcribe_file_to_srt,
235
  inputs=[
236
  gr.Audio(sources="upload", type="filepath", label="Audio file"),
237
  gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
238
- gr.Checkbox(label="Use VAD + Auto Punctuation (simple)", value=True),
239
  ],
240
  outputs=[
241
- gr.Textbox(label="SRT Preview (no numbering, auto punctuation)", lines=18),
242
  gr.File(label="Download SRT"),
243
  ],
244
- title="Upload Audio → SRT (VAD + Auto Punctuation, No Numbering)",
245
- description=(
246
- "Optional VAD (via faster-whisper) for cleaner segments. "
247
- "Adds simple punctuation by pause/length. Adjustable MAX_SEG_DUR / MAX_SEG_CHARS / PAUSE_*."
248
- ),
249
  allow_flagging="never",
250
  )
251
 
 
2
  import spaces
3
  import torch
4
  import gradio as gr
5
+ from transformers import pipeline
6
  import tempfile
7
  import os
8
 
9
+ # ===== 参数 =====
10
+ MODEL_NAME = "openai/whisper-large-v3"
 
 
 
 
 
 
 
 
 
11
  BATCH_SIZE = 8
12
  FILE_LIMIT_MB = 1000
13
+ MAX_SEG_DUR = 6.0
14
+ MAX_SEG_CHARS = 28
15
+
16
+ device = 0 if torch.cuda.is_available() else "cpu"
17
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
18
+
19
+ pipe = pipeline(
20
+ task="automatic-speech-recognition",
21
+ model=MODEL_NAME,
22
+ chunk_length_s=30,
23
+ device=device,
24
+ torch_dtype=dtype,
25
+ return_timestamps=True,
26
+ )
27
 
28
+ # ===== 时间戳格式化 =====
29
  def _srt_timestamp(seconds: float | None) -> str:
30
  if seconds is None or seconds < 0:
31
  seconds = 0.0
 
35
  s, ms = divmod(ms, 1000)
36
  return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
37
 
38
+ # ===== 文本切分 + 自动补标点 =====
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def _split_text_units(txt: str, max_seg_chars: int) -> list[str]:
40
+ strong = "。!?.!?"
41
  units, cur = [], []
42
  for ch in txt:
43
  cur.append(ch)
 
52
  if len(u) <= max_seg_chars:
53
  refined.append(u)
54
  else:
55
+ # 长句继续切分,并自动补句号
56
  for i in range(0, len(u), max_seg_chars):
57
+ piece = u[i:i+max_seg_chars].strip()
58
+ if piece and piece[-1] not in strong:
59
+ piece += "。"
60
+ refined.append(piece)
61
+ # 如果最后一段没有标点,补句号
62
+ if refined and refined[-1][-1] not in strong:
63
+ refined[-1] += "。"
64
  return [x for x in refined if x]
65
 
66
+ # ===== chunks SRT (无编号 + 自动标点) =====
67
+ def chunks_to_srt(chunks: list[dict],
68
+ max_seg_dur: float = MAX_SEG_DUR,
69
+ max_seg_chars: int = MAX_SEG_CHARS) -> str:
70
  lines = []
 
 
71
  for ch in chunks or []:
72
  text = (ch.get("text") or "").strip()
73
  if not text:
 
79
  else:
80
  c_start, c_end = 0.0, 2.0
81
 
 
 
 
 
 
 
 
82
  units = _split_text_units(text, max_seg_chars)
83
  if not units:
84
  units = [text]
 
88
 
89
  cur_t = c_start
90
  for u in units:
91
+ alloc = total_dur * (len(u) / total_chars)
92
+ alloc = max(alloc, 0.2)
 
93
  if alloc <= max_seg_dur:
94
  pieces = [u]
95
  per = alloc
96
  else:
97
+ # 再次切分,均匀分时长
98
+ smalls = [u[i:i+max_seg_chars] for i in range(0, len(u), max_seg_chars)]
99
+ pieces = [s.strip() + ("。" if not s.endswith("。") else "") for s in smalls if s.strip()]
100
+ per = min(max_seg_dur, alloc / max(1, len(pieces)))
101
+
102
+ for p in pieces:
103
+ if p and p[-1] not in "。!?.!?":
104
+ p += "。"
105
  st = cur_t
106
  en = st + per
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  lines.append(f"{_srt_timestamp(st)} --> {_srt_timestamp(en)}")
108
+ lines.append(p.strip())
109
  lines.append("")
110
  cur_t = en
111
 
112
  return "\n".join(lines).strip() + ("\n" if lines else "")
113
 
114
+ # ===== 上传音频 SRT =====
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  @spaces.GPU
116
+ def transcribe_file_to_srt(audio_path: str, task: str):
117
  if not audio_path:
118
  raise gr.Error("请先上传音频文件。")
119
  try:
 
123
  except OSError:
124
  pass
125
 
126
+ result = pipe(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task})
127
+ text = result.get("text", "") or ""
128
+ chunks = result.get("chunks") or []
 
 
129
 
130
+ srt_str = chunks_to_srt(chunks)
131
+ if not srt_str and text.strip():
132
+ srt_str = "00:00:00,000 --> 00:00:02,000\n" + (text.strip() + "。") + "\n"
133
 
 
134
  tmpdir = tempfile.mkdtemp(prefix="srt_")
135
  base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
136
  srt_path = os.path.join(tmpdir, f"{base}.srt")
 
139
 
140
  return srt_str, srt_path
141
 
142
+ # ===== 界面 =====
143
  demo = gr.Interface(
144
  fn=transcribe_file_to_srt,
145
  inputs=[
146
  gr.Audio(sources="upload", type="filepath", label="Audio file"),
147
  gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
 
148
  ],
149
  outputs=[
150
+ gr.Textbox(label="Transcript (SRT Preview)", lines=18),
151
  gr.File(label="Download SRT"),
152
  ],
153
+ title="Upload Audio → SRT Subtitle",
154
+ description=f"Upload an audio file to generate time-stamped SRT subtitles (auto punctuation, no numbering). Model: {MODEL_NAME}",
 
 
 
155
  allow_flagging="never",
156
  )
157