datxy committed on
Commit
a33f970
·
verified ·
1 Parent(s): caa6c38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -97
app.py CHANGED
@@ -1,16 +1,17 @@
 
1
  import spaces
2
  import torch
3
  import gradio as gr
4
  from transformers import pipeline
5
-
6
  import tempfile
7
  import os
8
- from datetime import timedelta
9
 
10
- # ===== 配置 =====
11
  MODEL_NAME = "openai/whisper-large-v3"
12
  BATCH_SIZE = 8
13
- FILE_LIMIT_MB = 1000 # 最大 1000MB
 
 
14
 
15
  device = 0 if torch.cuda.is_available() else "cpu"
16
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
@@ -21,12 +22,11 @@ pipe = pipeline(
21
  chunk_length_s=30,
22
  device=device,
23
  torch_dtype=dtype,
24
- return_timestamps="word", # 关键:逐词时间戳,便于细分
25
  )
26
 
27
- # ===== 工具函数:时间戳/SRT =====
28
- def _srt_timestamp(seconds):
29
- """秒 -> SRT 时间戳 00:00:00,000。None/负数时归零。"""
30
  if seconds is None or seconds < 0:
31
  seconds = 0.0
32
  ms = int(float(seconds) * 1000 + 0.5)
@@ -35,92 +35,83 @@ def _srt_timestamp(seconds):
35
  s, ms = divmod(ms, 1000)
36
  return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
37
 
38
- def chunks_to_srt(chunks, text_fallback="", max_seg_dur=6.0, max_seg_chars=42):
39
- """
40
- 用逐词时间戳把长 chunk 细分成更短的 SRT 行:
41
- - 每行最长持续 max_seg_dur
42
- - 或字符数约 max_seg_chars
43
- - 遇到句末标点(。!?.!?)优先断句
44
- """
45
- segs = []
46
- cur_words = []
47
- cur_start = None
48
- cur_len = 0
49
-
50
- def flush_seg():
51
- nonlocal cur_words, cur_start, cur_len
52
- if not cur_words:
53
- return
54
- # 兼容多种时间戳字段
55
- st = cur_start if cur_start is not None else cur_words[0].get("start", 0.0)
56
- en = cur_words[-1].get("end", cur_words[-1].get("timestamp", [0.0, 0.0])[-1] if isinstance(cur_words[-1].get("timestamp"), (list, tuple)) else 0.0)
57
- if isinstance(st, (list, tuple)): st = st[0]
58
- if isinstance(en, (list, tuple)): en = en[-1]
59
- text = "".join(w.get("word", "").strip() for w in cur_words).strip()
60
- if text:
61
- segs.append((float(st or 0.0), float(en or 0.0), text))
62
- cur_words = []
63
- cur_start = None
64
- cur_len = 0
65
-
66
- def maybe_flush(force=False, strong_punct=False):
67
- if not cur_words:
68
- return
69
- st = cur_start if cur_start is not None else cur_words[0].get("start", 0.0)
70
- en = cur_words[-1].get("end", cur_words[-1].get("timestamp", [0.0, 0.0])[-1] if isinstance(cur_words[-1].get("timestamp"), (list, tuple)) else 0.0)
71
- if isinstance(st, (list, tuple)): st = st[0]
72
- if isinstance(en, (list, tuple)): en = en[-1]
73
- dur = float((en or 0.0) - (st or 0.0))
74
- if force or strong_punct or dur >= max_seg_dur or cur_len >= max_seg_chars:
75
- flush_seg()
76
-
77
- # 汇总所有词
78
- all_words = []
79
  for ch in chunks or []:
80
- words = ch.get("words") or []
81
- if not words and ch.get("text"):
82
- ts = ch.get("timestamp") or ch.get("timestamps") or [0.0, 2.0]
83
- if isinstance(ts, (list, tuple)) and len(ts) == 2:
84
- all_words.append({"word": ch["text"], "start": ts[0], "end": ts[1]})
85
- else:
86
- all_words.append({"word": ch["text"], "start": 0.0, "end": 2.0})
87
  continue
88
- for w in words:
89
- token = (w.get("word") or "").replace("\n", " ")
90
- start = w.get("start")
91
- end = w.get("end")
92
- if (start is None or end is None) and isinstance(w.get("timestamp"), (list, tuple)) and len(w["timestamp"]) == 2:
93
- start, end = w["timestamp"]
94
- all_words.append({"word": token, "start": start, "end": end})
95
-
96
- # 若依旧拿不到逐词,回退整段文本
97
- if not all_words and text_fallback.strip():
98
- all_words = [{"word": text_fallback.strip(), "start": 0.0, "end": max_seg_dur}]
99
-
100
- # 按规则切分
101
- for w in all_words:
102
- token = w.get("word", "")
103
- if not token:
104
- continue
105
- if cur_start is None:
106
- cur_start = w.get("start", 0.0)
107
- cur_words.append(w)
108
- cur_len += len(token)
109
- strong = token.endswith(("。", "!", "?", ".", "!", "?"))
110
- maybe_flush(force=False, strong_punct=strong)
111
 
112
- maybe_flush(force=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
- # 生成 SRT
115
- lines = []
116
- for i, (st, en, txt) in enumerate(segs, 1):
117
- lines.append(str(i))
118
- lines.append(f"{_srt_timestamp(st)} --> {_srt_timestamp(en)}")
119
- lines.append(txt)
120
- lines.append("")
121
  return "\n".join(lines).strip() + ("\n" if lines else "")
122
 
123
- # ===== 上传音频 -> SRT 导出 =====
124
  @spaces.GPU
125
  def transcribe_file_to_srt(audio_path: str, task: str):
126
  if not audio_path:
@@ -136,20 +127,19 @@ def transcribe_file_to_srt(audio_path: str, task: str):
136
  text = result.get("text", "") or ""
137
  chunks = result.get("chunks") or []
138
 
139
- # SRT(预览即为 SRT)
140
- srt_str = chunks_to_srt(chunks, text_fallback=text)
 
141
 
142
- # 写入临时文件供下载
143
  tmpdir = tempfile.mkdtemp(prefix="srt_")
144
  base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
145
  srt_path = os.path.join(tmpdir, f"{base}.srt")
146
  with open(srt_path, "w", encoding="utf-8") as f:
147
  f.write(srt_str)
148
 
149
- # 第一个输出显示 SRT 字符串,第二个输出提供下载
150
  return srt_str, srt_path
151
 
152
- # ===== Gradio 界面 =====
153
  demo = gr.Interface(
154
  fn=transcribe_file_to_srt,
155
  inputs=[
@@ -161,10 +151,7 @@ demo = gr.Interface(
161
  gr.File(label="Download SRT"),
162
  ],
163
  title="Upload Audio → SRT Subtitle",
164
- description=(
165
- "Upload an audio file to generate time-stamped SRT subtitles. "
166
- f"Backed by [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME})."
167
- ),
168
  allow_flagging="never",
169
  )
170
 
 
1
+ # app.py
2
  import spaces
3
  import torch
4
  import gradio as gr
5
  from transformers import pipeline
 
6
  import tempfile
7
  import os
 
8
 
9
+ # ===== 参数 =====
10
  MODEL_NAME = "openai/whisper-large-v3"
11
  BATCH_SIZE = 8
12
+ FILE_LIMIT_MB = 1000
13
+ MAX_SEG_DUR = 6.0
14
+ MAX_SEG_CHARS = 28
15
 
16
  device = 0 if torch.cuda.is_available() else "cpu"
17
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
22
  chunk_length_s=30,
23
  device=device,
24
  torch_dtype=dtype,
25
+ return_timestamps=True,
26
  )
27
 
28
+ # ===== 时间戳格式化 =====
29
+ def _srt_timestamp(seconds: float | None) -> str:
 
30
  if seconds is None or seconds < 0:
31
  seconds = 0.0
32
  ms = int(float(seconds) * 1000 + 0.5)
 
35
  s, ms = divmod(ms, 1000)
36
  return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
37
 
38
+ # ===== 文本切分 + 自动补标点 =====
39
+ def _split_text_units(txt: str, max_seg_chars: int) -> list[str]:
40
+ strong = "。!?.!?"
41
+ units, cur = [], []
42
+ for ch in txt:
43
+ cur.append(ch)
44
+ if ch in strong:
45
+ units.append("".join(cur).strip())
46
+ cur = []
47
+ if cur:
48
+ units.append("".join(cur).strip())
49
+
50
+ refined = []
51
+ for u in units:
52
+ if len(u) <= max_seg_chars:
53
+ refined.append(u)
54
+ else:
55
+ # 长句继续切分,并自动补句号
56
+ for i in range(0, len(u), max_seg_chars):
57
+ piece = u[i:i+max_seg_chars].strip()
58
+ if piece and piece[-1] not in strong:
59
+ piece += ""
60
+ refined.append(piece)
61
+ # 如果最后一段没有标点,补句号
62
+ if refined and refined[-1][-1] not in strong:
63
+ refined[-1] += "。"
64
+ return [x for x in refined if x]
65
+
66
+ # ===== chunks 转 SRT (无编号 + 自动标点) =====
67
def chunks_to_srt(chunks: list[dict],
                  max_seg_dur: float = MAX_SEG_DUR,
                  max_seg_chars: int = MAX_SEG_CHARS) -> str:
    """Render transcription chunks as SRT-style cues without sequence numbers.

    Each chunk's text is split into short units with ``_split_text_units``.
    The chunk's time span is shared among the units in proportion to their
    character counts. A unit whose share exceeds ``max_seg_dur`` is sliced
    again into ``max_seg_chars``-wide pieces, each shown for at most
    ``max_seg_dur`` seconds. Every cue is guaranteed to end with a sentence
    terminator.

    Args:
        chunks: Dicts carrying a ``"text"`` entry and a two-element
            ``"timestamp"`` (or ``"timestamps"``) entry of start/end seconds.
            ``None`` or an empty list yields an empty string.
        max_seg_dur: Longest display time of a single cue, in seconds.
        max_seg_chars: Longest cue text, in characters (clamped to >= 1 when
            re-slicing).

    Returns:
        Cues as "start --> end", text, blank line, joined by newlines and
        terminated by a single newline; ``""`` when no cue was produced.
    """
    # CJK full stop plus full-width and ASCII sentence-ending marks.
    terminators = "。！？.!?"
    min_cue_dur = 0.2      # floor so very short units remain readable
    fallback_span = 2.0    # assumed length when a chunk has no usable end time
    width = max(1, int(max_seg_chars))

    lines: list[str] = []
    for ch in chunks or []:
        text = (ch.get("text") or "").strip()
        if not text:
            continue

        ts = ch.get("timestamp") or ch.get("timestamps") or [0.0, fallback_span]
        if isinstance(ts, (list, tuple)) and len(ts) == 2:
            c_start = float(ts[0] or 0.0)
            # NOTE(review): the ASR pipeline can apparently leave the end of a
            # chunk as None; treating that as 0.0 would collapse the chunk to
            # zero duration, so assume the fallback span instead - confirm.
            c_end = float(ts[1]) if ts[1] is not None else c_start + fallback_span
        else:
            c_start, c_end = 0.0, fallback_span

        units = _split_text_units(text, max_seg_chars) or [text]
        total_chars = sum(len(u) for u in units) or 1
        total_dur = max(c_end - c_start, 0.0)

        cur_t = c_start
        for u in units:
            alloc = max(total_dur * (len(u) / total_chars), min_cue_dur)
            if alloc <= max_seg_dur:
                pieces = [u]
                per = alloc
            else:
                # Share is too long for one cue: slice again and spread evenly.
                slices = (u[i:i + width].strip() for i in range(0, len(u), width))
                pieces = [s for s in slices if s]
                per = min(max_seg_dur, alloc / max(1, len(pieces)))

            unit_start = cur_t
            for p in pieces:
                # Single place that adds punctuation, so a piece ending in
                # "?" or "!" is not turned into "?。".
                if p[-1] not in terminators:
                    p += "。"
                lines.append(f"{_srt_timestamp(cur_t)} --> {_srt_timestamp(cur_t + per)}")
                lines.append(p)
                lines.append("")
                cur_t += per

            # Advance by the unit's full share even when cue durations were
            # capped; otherwise every later cue in this chunk starts early.
            cur_t = unit_start + alloc

    return "\n".join(lines).strip() + ("\n" if lines else "")
113
 
114
+ # ===== 上传音频 SRT =====
115
  @spaces.GPU
116
  def transcribe_file_to_srt(audio_path: str, task: str):
117
  if not audio_path:
 
127
  text = result.get("text", "") or ""
128
  chunks = result.get("chunks") or []
129
 
130
+ srt_str = chunks_to_srt(chunks)
131
+ if not srt_str and text.strip():
132
+ srt_str = "00:00:00,000 --> 00:00:02,000\n" + (text.strip() + "。") + "\n"
133
 
 
134
  tmpdir = tempfile.mkdtemp(prefix="srt_")
135
  base = os.path.splitext(os.path.basename(audio_path))[0] or "subtitle"
136
  srt_path = os.path.join(tmpdir, f"{base}.srt")
137
  with open(srt_path, "w", encoding="utf-8") as f:
138
  f.write(srt_str)
139
 
 
140
  return srt_str, srt_path
141
 
142
+ # ===== 界面 =====
143
  demo = gr.Interface(
144
  fn=transcribe_file_to_srt,
145
  inputs=[
 
151
  gr.File(label="Download SRT"),
152
  ],
153
  title="Upload Audio → SRT Subtitle",
154
+ description=f"Upload an audio file to generate time-stamped SRT subtitles (auto punctuation, no numbering). Model: {MODEL_NAME}",
 
 
 
155
  allow_flagging="never",
156
  )
157