Heidel Medina committed on
Commit e935b66 · 1 Parent(s): 3d86161

Added advanced settings feature to main.py

Files changed (1):
  main.py  +135 -25
main.py CHANGED
@@ -9,35 +9,89 @@ import tempfile
 import re
 import textwrap
 
-def process_media(model_size, source_lang, upload, model_type):
+def process_media(
+    model_size, source_lang, upload, model_type,
+    max_chars, max_words, extend_in, extend_out, collapse_gaps,
+    max_lines_per_segment, line_penalty, longest_line_char_penalty, *args
+):
     if upload is None:
         return None, None, None, None, "No file uploaded."
 
     temp_path = upload.name
+    base_path = os.path.splitext(temp_path)[0]
+    word_transcription_path = base_path + '.json'
 
-    if model_type == "faster whisper":
-        model = stable_whisper.load_faster_whisper(model_size, device="cuda")
+    if os.path.exists(word_transcription_path):
+        print(f"Transcription data file found at {word_transcription_path}")
+        result = stable_whisper.WhisperResult(word_transcription_path)
     else:
-        model = stable_whisper.load_model(model_size, device="cuda")
-
-    try:
-        result = model.transcribe(temp_path, language=source_lang, vad=False, regroup=False)
-    except Exception as e:
-        return None, None, None, None, f"Transcription failed: {e}"
-
-    for i, segment in enumerate(result):
-        if i+1 == len(result):
-            break
-        next_start = result[i+1].start
-        if next_start - segment.end <= 0.100:
-            segment.end = next_start
-
-    srt_file = tempfile.NamedTemporaryFile(delete=False, suffix=".srt", mode="w", encoding="utf-8")
-    result.to_srt_vtt(srt_file.name, word_level=False)
-    srt_file.close()
-    srt_file_path = srt_file.name
-
-    # Transcript as plain text
+        print(f"Can't find transcription data file at {word_transcription_path}. Starting transcribing ...")
+        if model_type == "faster whisper":
+            model = stable_whisper.load_faster_whisper(model_size, device="cuda")
+        else:
+            model = stable_whisper.load_model(model_size, device="cuda")
+        try:
+            result = model.transcribe(temp_path, language=source_lang, vad=True, regroup=False, denoiser="demucs")
+        except Exception as e:
+            return None, None, None, None, f"Transcription failed: {e}"
+        result.save_as_json(word_transcription_path)
+
+    if max_chars or max_words:
+        result.split_by_length(
+            max_chars=int(max_chars) if max_chars else None,
+            max_words=int(max_words) if max_words else None
+        )
+
+    # ----- Perform segment time extensions and anti-flickering (=closing the gaps) -----
+    extend_start = float(extend_in) if extend_in else 0.0
+    extend_end = float(extend_out) if extend_out else 0.0
+    collapse_gaps_under = float(collapse_gaps) if collapse_gaps else 0.0
+
+    for i in range(len(result) - 1):
+        cur = result[i]
+        next = result[i+1]
+
+        if next.start - cur.end < extend_start + extend_end:
+            # Not enough time to add the entire desired extensions -> add proportionally
+            k = extend_end / (extend_start + extend_end) if (extend_start + extend_end) > 0 else 0
+            mid = cur.end * (1 - k) + next.start * k
+            cur.end = next.start = mid
+        else:
+            # Add full desired extensions
+            cur.end += extend_end
+            next.start -= extend_start
+
+        if next.start - cur.end <= collapse_gaps_under:
+            cur.end = next.start = (cur.end + next.start) / 2
+
+    if result:
+        result[0].start = max(0, result[0].start - extend_start)
+        result[-1].end += extend_end
+
+    #for seg in result:
+    #    seg.text = optimize_text(
+    #        seg.text,
+    #        int(max_lines_per_segment) if max_lines_per_segment else 3,
+    #        float(line_penalty) if line_penalty else 22.01,
+    #        float(longest_line_char_penalty) if longest_line_char_penalty else 1.0
+    #    )
+
+    # Use custom SRT block output
+    subtitles_path = tempfile.NamedTemporaryFile(delete=False, suffix=".srt", mode="w", encoding="utf-8").name
+    result_to_any(
+        result=result,
+        filepath=subtitles_path,
+        filetype='srt',
+        segments2blocks=lambda segments: segments2blocks(
+            segments,
+            int(max_lines_per_segment) if max_lines_per_segment else 3,
+            float(line_penalty) if line_penalty else 22.01,
+            float(longest_line_char_penalty) if longest_line_char_penalty else 1.0
+        ),
+        word_level=False,
+    )
+    srt_file_path = subtitles_path
+
     transcript_txt = result.to_txt()
 
     mime, _ = mimetypes.guess_type(temp_path)
@@ -45,7 +99,59 @@ def process_media(model_size, source_lang, upload, model_type):
     video_out = temp_path if mime and mime.startswith("video") else None
 
     return audio_out, video_out, transcript_txt, srt_file_path, None
-
+
+def optimize_text(text, max_lines_per_segment, line_penalty, longest_line_char_penalty):
+    text = text.strip()
+    words = text.split()
+
+    # Compute prefix sums
+    psum = [0]
+    for w in words:
+        psum += [psum[-1] + len(w) + 1]  # +1 because of spaces
+
+    bestScore = 10 ** 30
+    bestSplit = None
+
+    def backtrack(level, wordsUsed, maxLineLength, split):
+        nonlocal bestScore, bestSplit
+
+        if wordsUsed == len(words):
+            score = level * line_penalty + maxLineLength * longest_line_char_penalty
+            if score < bestScore:
+                bestScore = score
+                bestSplit = split
+            return
+
+        if level + 1 == max_lines_per_segment:
+            backtrack(
+                level + 1, len(words),
+                max(maxLineLength, psum[len(words)] - psum[wordsUsed] - 1),
+                split + [words[wordsUsed:]]
+            )
+            return
+
+        for levelWords in range(1, len(words) - wordsUsed + 1):
+            backtrack(
+                level + 1, wordsUsed + levelWords,
+                max(maxLineLength, psum[wordsUsed + levelWords] - psum[wordsUsed] - 1),
+                split + [words[wordsUsed:wordsUsed + levelWords]]
+            )
+
+    backtrack(0, 0, 0, [])
+
+    optimized = '\n'.join(' '.join(words) for words in bestSplit)
+    return optimized
+
+def segment2optimizedsrtblock(segment: dict, idx: int, max_lines_per_segment, line_penalty, longest_line_char_penalty, strip=True) -> str:
+    return f'{idx}\n{sec2srt(segment["start"])} --> {sec2srt(segment["end"])}\n' \
+           f'{optimize_text(segment["text"], max_lines_per_segment, line_penalty, longest_line_char_penalty)}'
+
+def segments2blocks(segments, max_lines_per_segment, line_penalty, longest_line_char_penalty):
+    return '\n\n'.join(
+        segment2optimizedsrtblock(s, i, max_lines_per_segment, line_penalty, longest_line_char_penalty, strip=True)
+        for i, s in enumerate(segments)
+    )
+
 WHISPER_LANGUAGES = [
     ("Afrikaans", "af"),
     ("Albanian", "sq"),
@@ -290,7 +396,11 @@ with gr.Blocks() as interface:
 
     submit_btn.click(
        fn=process_media,
-       inputs=[model_size, source_lang, file_input, model_type],
+       inputs=[
+           model_size, source_lang, file_input, model_type,
+           max_chars, max_words, extend_in, extend_out, collapse_gaps,
+           max_lines_per_segment, line_penalty, longest_line_char_penalty
+       ],
        outputs=[audio_output, video_output, transcript_output, srt_output]
     )
 
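
Note on the new gap handling: each subtitle is stretched into the silence around it (extend_in / extend_out), the gap is split proportionally when it is too small for both extensions, and any remaining gap at or below collapse_gaps is closed so subtitles do not flicker. Below is a minimal standalone sketch of that arithmetic on plain (start, end) tuples; the adjust() helper and the sample timings are illustrative, not part of main.py.

def adjust(segments, extend_in=0.2, extend_out=0.4, collapse_gaps=0.3):
    # Standalone re-statement of the commit's extend/collapse pass on [start, end] pairs.
    segs = [list(s) for s in segments]
    for i in range(len(segs) - 1):
        cur, nxt = segs[i], segs[i + 1]
        if nxt[0] - cur[1] < extend_in + extend_out:
            # Gap too small for both extensions: split it proportionally.
            k = extend_out / (extend_in + extend_out) if (extend_in + extend_out) > 0 else 0
            cur[1] = nxt[0] = cur[1] * (1 - k) + nxt[0] * k
        else:
            cur[1] += extend_out   # extend the current subtitle's end
            nxt[0] -= extend_in    # pull the next subtitle's start earlier
        if nxt[0] - cur[1] <= collapse_gaps:
            cur[1] = nxt[0] = (cur[1] + nxt[0]) / 2   # close what little gap remains
    if segs:
        segs[0][0] = max(0, segs[0][0] - extend_in)
        segs[-1][1] += extend_out
    return [tuple(s) for s in segs]

print(adjust([(0.5, 1.8), (2.0, 3.1), (3.2, 4.0)]))
# -> roughly [(0.3, 1.93), (1.93, 3.17), (3.17, 4.4)]: every gap between subtitles is closed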
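
The line wrapping behind optimize_text and segments2blocks scores every candidate split of a segment's text as lines_used * line_penalty + longest_line_length * longest_line_char_penalty and keeps the cheapest split, so line_penalty trades extra lines against line width. Here is a small brute-force check of that cost model on one sample sentence; score() and best_split() are illustrative stand-ins, not the recursive backtracking functions committed in main.py.

from itertools import combinations

def score(lines, line_penalty=22.01, longest_line_char_penalty=1.0):
    # Same cost model as the commit: each extra line and each character of the longest line add penalty.
    return len(lines) * line_penalty + max(len(l) for l in lines) * longest_line_char_penalty

def best_split(text, max_lines=3):
    words = text.split()
    best = None
    # Enumerate every way to cut the words into at most max_lines contiguous lines.
    for n_lines in range(1, min(max_lines, len(words)) + 1):
        for cuts in combinations(range(1, len(words)), n_lines - 1):
            bounds = [0, *cuts, len(words)]
            lines = [' '.join(words[a:b]) for a, b in zip(bounds, bounds[1:])]
            if best is None or score(lines) < score(best):
                best = lines
    return '\n'.join(best)

print(best_split("this subtitle segment is a bit too long to fit on a single line"))
# -> "this subtitle segment is a bit" / "too long to fit on a single line"
# (dropping to one line would save 22.01 in line penalty but add 31 to the longest-line term)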