datxy committed on
Commit
5085deb
·
verified ·
1 Parent(s): cf205ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -133
app.py CHANGED
@@ -6,30 +6,32 @@ from transformers import pipeline
6
  import tempfile
7
  import os
8
 
9
- # ================== 可调参数 ==================
10
  MODEL_NAME = "openai/whisper-large-v3"
11
  BATCH_SIZE = 8
12
  FILE_LIMIT_MB = 1000
13
 
14
- TARGET_SENT_CHARS = 12 # 目标每句字数(中文场景)
15
- COMMA_EVERY = 0 # 如需更细粒度短停顿,可设 6/8(表示每 N 字加一个“,”并收句);0 表示关闭
16
- MAX_SEG_DUR = 6.0 # 每句最长时长(秒)
17
- MIN_PIECE_DUR = 0.30 # 每句最小时长(秒),避免闪烁
18
- STRONG_PUNCT = "。!?.!?"
19
 
 
20
  device = 0 if torch.cuda.is_available() else "cpu"
21
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
22
 
 
23
  asr = pipeline(
24
  task="automatic-speech-recognition",
25
  model=MODEL_NAME,
26
  chunk_length_s=30,
27
  device=device,
28
  torch_dtype=dtype,
29
- return_timestamps=True, # 仅需 chunk 级 (start,end)
30
  )
31
 
 
32
  def _ts(t: float | None) -> str:
 
33
  if t is None or t < 0:
34
  t = 0.0
35
  ms = int(float(t) * 1000 + 0.5)
@@ -39,17 +41,20 @@ def _ts(t: float | None) -> str:
39
  return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
40
 
41
  def _norm_chunks(chunks: list[dict]) -> list[dict]:
42
- """规范化 chunks: [{'text': str, 'start': float, 'end': float}]"""
 
 
 
43
  out = []
44
  for ch in chunks or []:
45
  text = (ch.get("text") or "").strip()
46
- ts = ch.get("timestamp") or ch.get("timestamps") or [0.0, 2.0]
47
  if not text:
48
  continue
49
  if isinstance(ts, (list, tuple)) and len(ts) == 2:
50
  s, e = float(ts[0] or 0.0), float(ts[1] or 0.0)
51
  else:
52
- s, e = 0.0, 2.0
53
  if e < s:
54
  e = s
55
  out.append({"text": text, "start": s, "end": e})
@@ -57,7 +62,7 @@ def _norm_chunks(chunks: list[dict]) -> list[dict]:
57
 
58
  def _char_timeline(chunk: dict) -> list[tuple[str, float, float]]:
59
  """
60
- 把一个 chunk 的文本按字符建立时间轴:
61
  返回 [(char, char_start, char_end), ...]
62
  """
63
  text = chunk["text"]
@@ -73,147 +78,65 @@ def _char_timeline(chunk: dict) -> list[tuple[str, float, float]]:
73
  cur = nxt
74
  return timeline
75
 
76
- def _segment_short_sentences(char_stream: list[tuple[str, float, float]],
77
- target_chars: int = TARGET_SENT_CHARS,
78
- comma_every: int = COMMA_EVERY,
79
- max_seg_dur: float = MAX_SEG_DUR) -> list[tuple[float, float, str]]:
80
  """
81
- 核心切分:
82
- - 累积字符直到遇到强标点或达到 target_chars
83
- - 可选:每 comma_every 个字符插入逗号并收句
84
- - 强标点永远并入本句,绝不产生“单独标点句”
85
- - 超长句再按时长 <= max_seg_dur 均匀切
86
  """
87
  segments = []
88
-
89
  buf_chars: list[str] = []
90
- buf_start = None
91
- last_char_end = None
92
- since_last_comma = 0
93
 
94
- def flush_sentence(force_punct=False):
95
- nonlocal buf_chars, buf_start, last_char_end, since_last_comma
96
  if not buf_chars:
97
  return
98
  text = "".join(buf_chars).strip()
99
- if not text:
100
- buf_chars = []
101
- buf_start = None
102
- since_last_comma = 0
103
- return
104
- # 保证句末有强标点
105
- if force_punct and text[-1] not in STRONG_PUNCT:
106
- text += "。"
107
- elif text[-1] not in STRONG_PUNCT:
108
- text += "。"
109
  st = buf_start if buf_start is not None else 0.0
110
- en = last_char_end if last_char_end is not None else st
111
- # 时长保护
112
- if en - st < MIN_PIECE_DUR:
113
- en = st + MIN_PIECE_DUR
114
  segments.append((st, en, text))
115
- buf_chars = []
116
- buf_start = None
117
- since_last_comma = 0
118
-
119
- def try_hard_wrap_long(st: float, en: float, text: str):
120
- """
121
- 如果单句太长(> max_seg_dur),按时长把文本均匀切成多块,每块 <= max_seg_dur,句末补句号。
122
- 使用 [st,en] 线性映射。
123
- """
124
- out = []
125
- dur = max(en - st, 0.0)
126
- if dur <= max_seg_dur:
127
- return [(st, en, text if text[-1] in STRONG_PUNCT else (text + "。"))]
128
 
129
- # 需要切成 k
130
- k = int(dur // max_seg_dur) + 1
131
- # 按字符再均匀切
132
- L = len(text)
133
- piece_len = max(L // k, 1)
134
- pos = 0
135
- for i in range(k):
136
- sub = text[pos:pos + piece_len]
137
- if not sub:
138
- continue
139
- sub_st = st + (i / k) * dur
140
- sub_en = st + ((i + 1) / k) * dur
141
- if sub[-1] not in STRONG_PUNCT:
142
- sub += "。"
143
- if sub_en - sub_st < MIN_PIECE_DUR:
144
- sub_en = sub_st + MIN_PIECE_DUR
145
- out.append((sub_st, sub_en, sub))
146
- pos += piece_len
147
- # 余数
148
- if pos < L:
149
- sub = text[pos:]
150
- sub_st = st + (len("".join([t for _,_,t in out])) / max(L,1)) * dur
151
- sub_en = en
152
- if sub and sub[-1] not in STRONG_PUNCT:
153
- sub += "。"
154
- if sub_en - sub_st < MIN_PIECE_DUR:
155
- sub_en = sub_st + MIN_PIECE_DUR
156
- out.append((sub_st, sub_en, sub))
157
- return out
158
-
159
- # 遍历逐字时间线
160
- for ch, ch_st, ch_en in char_stream:
161
  if buf_start is None:
162
- buf_start = ch_st
163
- buf_chars.append(ch)
164
- last_char_end = ch_en
165
- since_last_comma += 1
166
 
167
- # 强标点:直接并入当前句并收句
168
- if ch in STRONG_PUNCT:
169
- flush_sentence(force_punct=False)
170
- continue
171
-
172
- # 逗号式短停顿(可选)
173
- if comma_every and since_last_comma >= comma_every:
174
- # 只在当前累积达到目标一半以上时才加逗号,避免太碎
175
- if len(buf_chars) >= max(6, target_chars // 2):
176
- if buf_chars and buf_chars[-1] not in ",,、;;" and buf_chars[-1] not in STRONG_PUNCT:
177
- buf_chars.append(",")
178
- flush_sentence(force_punct=False)
179
- continue
180
- else:
181
- since_last_comma = 0 # 重置计数,继续攒
182
 
183
- # 达到目标长度:收句并补句号
184
- if len(buf_chars) >= target_chars:
185
- flush_sentence(force_punct=True)
186
 
187
  # 收尾
188
- flush_sentence(force_punct=True)
189
-
190
- # 二次处理:把任何超时长的句子再按时长切块(<= MAX_SEG_DUR)
191
- final_segments = []
192
- for st, en, tx in segments:
193
- if en - st > max_seg_dur:
194
- final_segments.extend(try_hard_wrap_long(st, en, tx))
195
- else:
196
- final_segments.append((st, en, tx if tx[-1] in STRONG_PUNCT else (tx + "。")))
197
-
198
- return final_segments
199
 
200
  def chunks_to_srt_no_number(chunks: list[dict]) -> str:
201
  """
202
- 外层封装:逐 chunk 建立字符时间线 → 合并切分 → 输出无编号 SRT
203
  """
204
  norm = _norm_chunks(chunks)
205
- # 构建全局逐字时间线(按 chunk 顺序拼接)
 
206
  char_stream = []
207
  for ch in norm:
208
  char_stream.extend(_char_timeline(ch))
209
 
210
- # 切分为短句片段
211
- segs = _segment_short_sentences(
212
- char_stream,
213
- target_chars=TARGET_SENT_CHARS,
214
- comma_every=COMMA_EVERY,
215
- max_seg_dur=MAX_SEG_DUR,
216
- )
217
 
218
  # 输出(无编号)
219
  lines = []
@@ -223,7 +146,7 @@ def chunks_to_srt_no_number(chunks: list[dict]) -> str:
223
  lines.append("")
224
  return "\n".join(lines).strip() + ("\n" if lines else "")
225
 
226
- # ================== 推理与UI ==================
227
  @spaces.GPU
228
  def transcribe_file_to_srt(audio_path: str, task: str):
229
  if not audio_path:
@@ -233,16 +156,22 @@ def transcribe_file_to_srt(audio_path: str, task: str):
233
  if size_mb > FILE_LIMIT_MB:
234
  raise gr.Error(f"文件过大:{size_mb:.1f} MB,超过限制 {FILE_LIMIT_MB} MB。")
235
  except OSError:
 
236
  pass
237
 
 
238
  result = asr(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task})
239
  text = (result.get("text") or "").strip()
240
  chunks = result.get("chunks") or []
241
 
 
242
  srt_str = chunks_to_srt_no_number(chunks)
 
 
243
  if not srt_str and text:
244
- srt_str = "00:00:00,000 --> 00:00:02,000\n" + (text + ("。" if text[-1] not in STRONG_PUNCT else "")) + "\n"
245
 
 
246
  tmpdir = tempfile.mkdtemp(prefix="srt_")
247
  base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
248
  srt_path = os.path.join(tmpdir, f"{base}.srt")
@@ -258,12 +187,12 @@ demo = gr.Interface(
258
  gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
259
  ],
260
  outputs=[
261
- gr.Textbox(label="Transcript (SRT Preview — short sentences, no numbering)", lines=18),
262
  gr.File(label="Download SRT"),
263
  ],
264
- title="Upload Audio → SRT (Short Sentences, No Numbering)",
265
- description=f"Character-timeline resegmentation. Natural short sentences like “他跟三国志他不一样。/ 他也是在那个基础上。/ …”. Model: {MODEL_NAME}",
266
  allow_flagging="never",
267
  )
268
 
269
- demo.queue().launch()
 
6
  import tempfile
7
  import os
8
 
9
+ # ================== 配置与可调参数 ==================
10
  MODEL_NAME = "openai/whisper-large-v3"
11
  BATCH_SIZE = 8
12
  FILE_LIMIT_MB = 1000
13
 
14
+ # —— 静音切分策略(本次需求)——
15
+ SILENCE_GAP = 0.2 # 当前字符开始时间与上一字符结束时间的间隔 >= 0.2s 触发切分
16
+ MIN_SEG_DUR = 0.30 # 每段最小时长,避免字幕闪烁
 
 
17
 
18
+ # 设备/精度
19
  device = 0 if torch.cuda.is_available() else "cpu"
20
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
21
 
22
+ # ASR 推理器:开启 chunk 级时间戳
23
  asr = pipeline(
24
  task="automatic-speech-recognition",
25
  model=MODEL_NAME,
26
  chunk_length_s=30,
27
  device=device,
28
  torch_dtype=dtype,
29
+ return_timestamps=True, # 返回 chunk 级别时间戳(start, end)
30
  )
31
 
32
+ # ================== 工具函数 ==================
33
  def _ts(t: float | None) -> str:
34
+ """将秒数转为 SRT 时间戳 00:00:00,000"""
35
  if t is None or t < 0:
36
  t = 0.0
37
  ms = int(float(t) * 1000 + 0.5)
 
41
  return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
42
 
43
  def _norm_chunks(chunks: list[dict]) -> list[dict]:
44
+ """
45
+ 规范化 chunks: [{'text': str, 'start': float, 'end': float}]
46
+ 兼容不同字段名:'timestamp' 或 'timestamps'
47
+ """
48
  out = []
49
  for ch in chunks or []:
50
  text = (ch.get("text") or "").strip()
51
+ ts = ch.get("timestamp") or ch.get("timestamps") or [0.0, 0.0]
52
  if not text:
53
  continue
54
  if isinstance(ts, (list, tuple)) and len(ts) == 2:
55
  s, e = float(ts[0] or 0.0), float(ts[1] or 0.0)
56
  else:
57
+ s, e = 0.0, 0.0
58
  if e < s:
59
  e = s
60
  out.append({"text": text, "start": s, "end": e})
 
62
 
63
  def _char_timeline(chunk: dict) -> list[tuple[str, float, float]]:
64
  """
65
+ 对单个 chunk 逐字符线性插值时间:
66
  返回 [(char, char_start, char_end), ...]
67
  """
68
  text = chunk["text"]
 
78
  cur = nxt
79
  return timeline
80
 
81
+ # ================== 按静音间隔切分(不依赖标点) ==================
82
def _segment_by_silence(char_stream: list[tuple[str, float, float]],
                        silence_gap: float = SILENCE_GAP,
                        min_seg_dur: float = MIN_SEG_DUR) -> list[tuple[float, float, str]]:
    """
    Split a per-character timeline into subtitle segments at silent gaps.

    A cut is made whenever the gap between the previous character's END time
    and the current character's START time is >= ``silence_gap``. Punctuation
    plays no role; characters are concatenated as-is.

    Args:
        char_stream: ``(char, start, end)`` tuples in playback order.
        silence_gap: minimum gap, in seconds, that triggers a cut.
        min_seg_dur: minimum segment duration, in seconds; a shorter segment
            has its end moved to ``start + min_seg_dur``.

    Returns:
        A list of ``(start, end, text)`` tuples. ``text`` is stripped of
        surrounding whitespace; buffers that are empty after stripping are
        dropped instead of being emitted as blank cues.
    """
    segments: list[tuple[float, float, str]] = []
    buf_chars: list[str] = []
    buf_start: float | None = None
    # End time of the most recently consumed character. It serves both as the
    # reference for gap detection and as the end of the segment being built.
    prev_end: float | None = None

    def flush() -> None:
        """Emit the buffered characters as one segment and reset the buffer."""
        nonlocal buf_chars, buf_start
        text = "".join(buf_chars).strip()
        if text:
            seg_start = buf_start if buf_start is not None else 0.0
            seg_end = prev_end if prev_end is not None else seg_start
            # Minimum-duration guard against flickering cues.
            # NOTE: the padded end is not clamped to the next segment's start,
            # so a very short segment may slightly overlap the following one.
            if seg_end - seg_start < min_seg_dur:
                seg_end = seg_start + min_seg_dur
            segments.append((seg_start, seg_end, text))
        # Reset even when nothing was emitted so whitespace never leaks into
        # the next segment's text or start time.
        buf_chars = []
        buf_start = None

    for ch, st, en in char_stream:
        # A long enough pause since the previous character closes the segment.
        if prev_end is not None and (st - prev_end) >= silence_gap:
            flush()

        # The first character of a fresh buffer defines the segment start.
        if buf_start is None:
            buf_start = st

        buf_chars.append(ch)
        prev_end = en

    # Emit whatever is left after the last character.
    flush()
    return segments
 
 
 
 
 
 
 
 
 
126
 
127
  def chunks_to_srt_no_number(chunks: list[dict]) -> str:
128
  """
129
+ 流程:规范化 chunks → 全局逐字符时间线 → 按静音切分 → 输出无编号 SRT
130
  """
131
  norm = _norm_chunks(chunks)
132
+
133
+ # 拼接全局字符流
134
  char_stream = []
135
  for ch in norm:
136
  char_stream.extend(_char_timeline(ch))
137
 
138
+ # 静音切分
139
+ segs = _segment_by_silence(char_stream)
 
 
 
 
 
140
 
141
  # 输出(无编号)
142
  lines = []
 
146
  lines.append("")
147
  return "\n".join(lines).strip() + ("\n" if lines else "")
148
 
149
+ # ================== 推理与 UI ==================
150
  @spaces.GPU
151
  def transcribe_file_to_srt(audio_path: str, task: str):
152
  if not audio_path:
 
156
  if size_mb > FILE_LIMIT_MB:
157
  raise gr.Error(f"文件过大:{size_mb:.1f} MB,超过限制 {FILE_LIMIT_MB} MB。")
158
  except OSError:
159
+ # 某些远端路径可能拿不到大小,忽略即可
160
  pass
161
 
162
+ # 运行 ASR
163
  result = asr(audio_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task})
164
  text = (result.get("text") or "").strip()
165
  chunks = result.get("chunks") or []
166
 
167
+ # 基于静音切分生成 SRT(无编号)
168
  srt_str = chunks_to_srt_no_number(chunks)
169
+
170
+ # 兜底:若无 chunks,则给出整段兜底字幕
171
  if not srt_str and text:
172
+ srt_str = "00:00:00,000 --> 00:00:02,000\n" + text + "\n"
173
 
174
+ # 写入临时文件,供下载
175
  tmpdir = tempfile.mkdtemp(prefix="srt_")
176
  base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
177
  srt_path = os.path.join(tmpdir, f"{base}.srt")
 
187
  gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
188
  ],
189
  outputs=[
190
+ gr.Textbox(label="Transcript (SRT Preview — silence-based, no numbering)", lines=18),
191
  gr.File(label="Download SRT"),
192
  ],
193
+ title="Upload Audio → SRT (Silence-Based Segmentation, No Numbering)",
194
+ description=f"切分规则:相邻字符静音间隔 {SILENCE_GAP}s 则切段;不依赖标点。模型:{MODEL_NAME}",
195
  allow_flagging="never",
196
  )
197
 
198
+ demo.queue().launch()