datxy committed on
Commit
cf205ff
·
verified ·
1 Parent(s): 261c97f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +198 -87
app.py CHANGED
@@ -6,112 +6,224 @@ from transformers import pipeline
6
  import tempfile
7
  import os
8
 
9
- # ===== 参数 =====
10
  MODEL_NAME = "openai/whisper-large-v3"
11
  BATCH_SIZE = 8
12
  FILE_LIMIT_MB = 1000
13
- MAX_SEG_DUR = 6.0
14
- MAX_SEG_CHARS = 28
 
 
 
 
15
 
16
  device = 0 if torch.cuda.is_available() else "cpu"
17
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
18
 
19
- pipe = pipeline(
20
  task="automatic-speech-recognition",
21
  model=MODEL_NAME,
22
  chunk_length_s=30,
23
  device=device,
24
  torch_dtype=dtype,
25
- return_timestamps=True,
26
  )
27
 
28
- # ===== 时间戳格式化 =====
29
- def _srt_timestamp(seconds: float | None) -> str:
30
- if seconds is None or seconds < 0:
31
- seconds = 0.0
32
- ms = int(float(seconds) * 1000 + 0.5)
33
  h, ms = divmod(ms, 3600000)
34
  m, ms = divmod(ms, 60000)
35
  s, ms = divmod(ms, 1000)
36
  return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
37
 
38
- # ===== 文本切分 + 自动补标点 =====
39
- def _split_text_units(txt: str, max_seg_chars: int) -> list[str]:
40
- strong = "。!?.!?"
41
- units, cur = [], []
42
- for ch in txt:
43
- cur.append(ch)
44
- if ch in strong:
45
- units.append("".join(cur).strip())
46
- cur = []
47
- if cur:
48
- units.append("".join(cur).strip())
49
-
50
- refined = []
51
- for u in units:
52
- if len(u) <= max_seg_chars:
53
- refined.append(u)
54
- else:
55
- # 长句继续切分,并自动补句号
56
- for i in range(0, len(u), max_seg_chars):
57
- piece = u[i:i+max_seg_chars].strip()
58
- if piece and piece[-1] not in strong:
59
- piece += "。"
60
- refined.append(piece)
61
- # 如果最后一段没有标点,补句号
62
- if refined and refined[-1][-1] not in strong:
63
- refined[-1] += "。"
64
- return [x for x in refined if x]
65
-
66
- # ===== chunks 转 SRT (无编号 + 自动标点) =====
67
- def chunks_to_srt(chunks: list[dict],
68
- max_seg_dur: float = MAX_SEG_DUR,
69
- max_seg_chars: int = MAX_SEG_CHARS) -> str:
70
- lines = []
71
  for ch in chunks or []:
72
  text = (ch.get("text") or "").strip()
 
73
  if not text:
74
  continue
75
-
76
- ts = ch.get("timestamp") or ch.get("timestamps") or [0.0, 2.0]
77
  if isinstance(ts, (list, tuple)) and len(ts) == 2:
78
- c_start, c_end = float(ts[0] or 0.0), float(ts[1] or 0.0)
79
  else:
80
- c_start, c_end = 0.0, 2.0
81
-
82
- units = _split_text_units(text, max_seg_chars)
83
- if not units:
84
- units = [text]
85
-
86
- total_chars = sum(len(u) for u in units) or 1
87
- total_dur = max(c_end - c_start, 0.0)
88
-
89
- cur_t = c_start
90
- for u in units:
91
- alloc = total_dur * (len(u) / total_chars)
92
- alloc = max(alloc, 0.2)
93
- if alloc <= max_seg_dur:
94
- pieces = [u]
95
- per = alloc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  else:
97
- # 再次切分,均匀分时长
98
- smalls = [u[i:i+max_seg_chars] for i in range(0, len(u), max_seg_chars)]
99
- pieces = [s.strip() + ("。" if not s.endswith("。") else "") for s in smalls if s.strip()]
100
- per = min(max_seg_dur, alloc / max(1, len(pieces)))
101
-
102
- for p in pieces:
103
- if p and p[-1] not in "。!?.!?":
104
- p += "。"
105
- st = cur_t
106
- en = st + per
107
- lines.append(f"{_srt_timestamp(st)} --> {_srt_timestamp(en)}")
108
- lines.append(p.strip())
109
- lines.append("")
110
- cur_t = en
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  return "\n".join(lines).strip() + ("\n" if lines else "")
113
 
114
- # ===== 上传音频 → SRT =====
115
  @spaces.GPU
116
  def transcribe_file_to_srt(audio_path: str, task: str):
117
  if not audio_path:
@@ -123,13 +235,13 @@ def transcribe_file_to_srt(audio_path: str, task: str):
123
  except OSError:
124
  pass
125
 
126
- result = pipe(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task})
127
- text = result.get("text", "") or ""
128
  chunks = result.get("chunks") or []
129
 
130
- srt_str = chunks_to_srt(chunks)
131
- if not srt_str and text.strip():
132
- srt_str = "00:00:00,000 --> 00:00:02,000\n" + (text.strip() + "。") + "\n"
133
 
134
  tmpdir = tempfile.mkdtemp(prefix="srt_")
135
  base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
@@ -139,7 +251,6 @@ def transcribe_file_to_srt(audio_path: str, task: str):
139
 
140
  return srt_str, srt_path
141
 
142
- # ===== 界面 =====
143
  demo = gr.Interface(
144
  fn=transcribe_file_to_srt,
145
  inputs=[
@@ -147,12 +258,12 @@ demo = gr.Interface(
147
  gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
148
  ],
149
  outputs=[
150
- gr.Textbox(label="Transcript (SRT Preview)", lines=18),
151
  gr.File(label="Download SRT"),
152
  ],
153
- title="Upload Audio → SRT Subtitle",
154
- description=f"Upload an audio file to generate time-stamped SRT subtitles (auto punctuation, no numbering). Model: {MODEL_NAME}",
155
  allow_flagging="never",
156
  )
157
 
158
- demo.queue().launch()
 
import tempfile
import os

# ================== Tunable parameters ==================
MODEL_NAME = "openai/whisper-large-v3"
BATCH_SIZE = 8
# NOTE(review): FILE_LIMIT_MB is declared but not referenced in the visible
# code — confirm whether the size limit is enforced elsewhere.
FILE_LIMIT_MB = 1000

TARGET_SENT_CHARS = 12   # target characters per subtitle sentence (Chinese text)
COMMA_EVERY = 0          # finer short pauses: set 6/8 to insert "," and close the sentence every N chars; 0 = off
MAX_SEG_DUR = 6.0        # maximum duration of one subtitle sentence (seconds)
MIN_PIECE_DUR = 0.30     # minimum duration of one subtitle sentence (seconds), avoids flickering subtitles
STRONG_PUNCT = "。!?.!?"  # sentence-ending punctuation (CJK full-width + ASCII)

# GPU index 0 when CUDA is available, otherwise CPU; fp16 only on GPU.
device = 0 if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Whisper ASR pipeline; chunk-level (start, end) timestamps are sufficient
# for the downstream character-timeline resegmentation.
asr = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,
    device=device,
    torch_dtype=dtype,
    return_timestamps=True,
)
31
 
32
+ def _ts(t: float | None) -> str:
33
+ if t is None or t < 0:
34
+ t = 0.0
35
+ ms = int(float(t) * 1000 + 0.5)
 
36
  h, ms = divmod(ms, 3600000)
37
  m, ms = divmod(ms, 60000)
38
  s, ms = divmod(ms, 1000)
39
  return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
40
 
41
+ def _norm_chunks(chunks: list[dict]) -> list[dict]:
42
+ """规范化 chunks: [{'text': str, 'start': float, 'end': float}]"""
43
+ out = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  for ch in chunks or []:
45
  text = (ch.get("text") or "").strip()
46
+ ts = ch.get("timestamp") or ch.get("timestamps") or [0.0, 2.0]
47
  if not text:
48
  continue
 
 
49
  if isinstance(ts, (list, tuple)) and len(ts) == 2:
50
+ s, e = float(ts[0] or 0.0), float(ts[1] or 0.0)
51
  else:
52
+ s, e = 0.0, 2.0
53
+ if e < s:
54
+ e = s
55
+ out.append({"text": text, "start": s, "end": e})
56
+ return out
57
+
58
+ def _char_timeline(chunk: dict) -> list[tuple[str, float, float]]:
59
+ """
60
+ 把一个 chunk 的文本按字符建立时间轴:
61
+ 返回 [(char, char_start, char_end), ...]
62
+ """
63
+ text = chunk["text"]
64
+ s, e = chunk["start"], chunk["end"]
65
+ dur = max(e - s, 0.0)
66
+ n = max(len(text), 1)
67
+ step = dur / n if n > 0 else 0.0
68
+ timeline = []
69
+ cur = s
70
+ for i, ch in enumerate(text):
71
+ nxt = s + (i + 1) * step
72
+ timeline.append((ch, cur, nxt))
73
+ cur = nxt
74
+ return timeline
75
+
76
def _segment_short_sentences(char_stream: list[tuple[str, float, float]],
                             target_chars: int = TARGET_SENT_CHARS,
                             comma_every: int = COMMA_EVERY,
                             max_seg_dur: float = MAX_SEG_DUR) -> list[tuple[float, float, str]]:
    """
    Core segmentation over a per-character timeline.

    - Accumulate characters until strong punctuation is hit or target_chars
      is reached.
    - Optionally (comma_every > 0) insert a comma and close the sentence
      every comma_every characters.
    - Strong punctuation is always merged into the current sentence; a
      "punctuation-only sentence" is never produced.
    - Over-long sentences are re-split so each piece lasts <= max_seg_dur.

    Returns a list of (start, end, text) tuples.
    """
    segments = []

    buf_chars: list[str] = []   # characters of the sentence being built
    buf_start = None            # start time of the first buffered character
    last_char_end = None        # end time of the most recent buffered character
    since_last_comma = 0        # characters seen since the last comma break

    def flush_sentence(force_punct=False):
        # Emit the buffered characters as one (start, end, text) segment
        # and reset the buffer state.
        # NOTE(review): both punctuation branches below append "。", so
        # force_punct currently has no effect here — kept as-is to preserve
        # behavior.
        nonlocal buf_chars, buf_start, last_char_end, since_last_comma
        if not buf_chars:
            return
        text = "".join(buf_chars).strip()
        if not text:
            buf_chars = []
            buf_start = None
            since_last_comma = 0
            return
        # Guarantee strong punctuation at the end of the sentence.
        if force_punct and text[-1] not in STRONG_PUNCT:
            text += "。"
        elif text[-1] not in STRONG_PUNCT:
            text += "。"
        st = buf_start if buf_start is not None else 0.0
        en = last_char_end if last_char_end is not None else st
        # Duration guard: avoid subtitles that flash by too quickly.
        if en - st < MIN_PIECE_DUR:
            en = st + MIN_PIECE_DUR
        segments.append((st, en, text))
        buf_chars = []
        buf_start = None
        since_last_comma = 0

    def try_hard_wrap_long(st: float, en: float, text: str):
        """
        If a single sentence is too long (> max_seg_dur), split its text
        evenly into blocks each lasting <= max_seg_dur, appending "。" to
        every block. Uses a linear time mapping over [st, en].
        """
        out = []
        dur = max(en - st, 0.0)
        if dur <= max_seg_dur:
            return [(st, en, text if text[-1] in STRONG_PUNCT else (text + "。"))]

        # Number of blocks needed so every block fits within max_seg_dur.
        k = int(dur // max_seg_dur) + 1
        # Split the characters evenly across the k blocks.
        L = len(text)
        piece_len = max(L // k, 1)
        pos = 0
        for i in range(k):
            sub = text[pos:pos + piece_len]
            if not sub:
                continue
            sub_st = st + (i / k) * dur
            sub_en = st + ((i + 1) / k) * dur
            if sub[-1] not in STRONG_PUNCT:
                sub += "。"
            if sub_en - sub_st < MIN_PIECE_DUR:
                sub_en = sub_st + MIN_PIECE_DUR
            out.append((sub_st, sub_en, sub))
            pos += piece_len
        # Remainder characters left over after the k even blocks.
        # NOTE(review): the ratio below counts the "。" characters appended
        # to the emitted pieces, so the remainder's start time can drift
        # slightly late relative to [st, en] — verify before relying on
        # exact timings here.
        if pos < L:
            sub = text[pos:]
            sub_st = st + (len("".join([t for _,_,t in out])) / max(L,1)) * dur
            sub_en = en
            if sub and sub[-1] not in STRONG_PUNCT:
                sub += "。"
            if sub_en - sub_st < MIN_PIECE_DUR:
                sub_en = sub_st + MIN_PIECE_DUR
            out.append((sub_st, sub_en, sub))
        return out

    # Walk the per-character timeline.
    for ch, ch_st, ch_en in char_stream:
        if buf_start is None:
            buf_start = ch_st
        buf_chars.append(ch)
        last_char_end = ch_en
        since_last_comma += 1

        # Strong punctuation: merge into the current sentence and close it.
        if ch in STRONG_PUNCT:
            flush_sentence(force_punct=False)
            continue

        # Comma-style short pause (optional, enabled via comma_every > 0).
        if comma_every and since_last_comma >= comma_every:
            # Only insert a comma once at least half the target length has
            # accumulated, to avoid overly fragmented sentences.
            if len(buf_chars) >= max(6, target_chars // 2):
                if buf_chars and buf_chars[-1] not in ",,、;;" and buf_chars[-1] not in STRONG_PUNCT:
                    buf_chars.append(",")
                flush_sentence(force_punct=False)
                continue
            else:
                since_last_comma = 0  # reset the counter and keep accumulating

        # Target length reached: close the sentence (appends "。").
        if len(buf_chars) >= target_chars:
            flush_sentence(force_punct=True)

    # Flush whatever remains at the end of the stream.
    flush_sentence(force_punct=True)

    # Second pass: re-split any segment whose duration exceeds max_seg_dur.
    final_segments = []
    for st, en, tx in segments:
        if en - st > max_seg_dur:
            final_segments.extend(try_hard_wrap_long(st, en, tx))
        else:
            final_segments.append((st, en, tx if tx[-1] in STRONG_PUNCT else (tx + "。")))

    return final_segments
199
+
200
def chunks_to_srt_no_number(chunks: list[dict]) -> str:
    """Top-level wrapper: normalize chunks, build a global character
    timeline, segment it into short sentences, and render numberless SRT.

    Returns "" when there are no usable chunks/segments.
    """
    # Concatenate per-chunk character timelines in chunk order.
    char_stream = []
    for chunk in _norm_chunks(chunks):
        char_stream += _char_timeline(chunk)

    # Resegment into short subtitle sentences.
    segments = _segment_short_sentences(
        char_stream,
        target_chars=TARGET_SENT_CHARS,
        comma_every=COMMA_EVERY,
        max_seg_dur=MAX_SEG_DUR,
    )

    # Render one SRT cue per segment — timing line, text line, blank line —
    # with no sequence numbers.
    cues = [
        f"{_ts(start)} --> {_ts(end)}\n{text.strip()}\n"
        for start, end, text in segments
    ]
    body = "\n".join(cues)
    return body.strip() + ("\n" if cues else "")
225
 
226
+ # ================== 推理与UI ==================
227
  @spaces.GPU
228
  def transcribe_file_to_srt(audio_path: str, task: str):
229
  if not audio_path:
 
235
  except OSError:
236
  pass
237
 
238
+ result = asr(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task})
239
+ text = (result.get("text") or "").strip()
240
  chunks = result.get("chunks") or []
241
 
242
+ srt_str = chunks_to_srt_no_number(chunks)
243
+ if not srt_str and text:
244
+ srt_str = "00:00:00,000 --> 00:00:02,000\n" + (text + ("。" if text[-1] not in STRONG_PUNCT else "")) + "\n"
245
 
246
  tmpdir = tempfile.mkdtemp(prefix="srt_")
247
  base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
 
251
 
252
  return srt_str, srt_path
253
 
 
254
  demo = gr.Interface(
255
  fn=transcribe_file_to_srt,
256
  inputs=[
 
258
  gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
259
  ],
260
  outputs=[
261
+ gr.Textbox(label="Transcript (SRT Preview — short sentences, no numbering)", lines=18),
262
  gr.File(label="Download SRT"),
263
  ],
264
+ title="Upload Audio → SRT (Short Sentences, No Numbering)",
265
+ description=f"Character-timeline resegmentation. Natural short sentences like “他跟三国志他不一样。/ 他也是在那个基础上。/ …”. Model: {MODEL_NAME}",
266
  allow_flagging="never",
267
  )
268
 
269
+ demo.queue().launch()