Heidel Medina. committed on
Commit
28808bd
·
1 Parent(s): e935b66

fixed a few minor issues in main.py

Browse files
Files changed (1) hide show
  1. main.py +22 -21
main.py CHANGED
@@ -8,41 +8,51 @@ from stable_whisper.text_output import result_to_any, sec2srt
8
  import tempfile
9
  import re
10
  import textwrap
 
11
 
 
12
  def process_media(
13
  model_size, source_lang, upload, model_type,
14
  max_chars, max_words, extend_in, extend_out, collapse_gaps,
15
  max_lines_per_segment, line_penalty, longest_line_char_penalty, *args
16
  ):
 
17
  if upload is None:
18
- return None, None, None, None, "No file uploaded."
19
 
20
  temp_path = upload.name
21
  base_path = os.path.splitext(temp_path)[0]
22
  word_transcription_path = base_path + '.json'
23
 
 
24
  if os.path.exists(word_transcription_path):
25
  print(f"Transcription data file found at {word_transcription_path}")
26
  result = stable_whisper.WhisperResult(word_transcription_path)
27
  else:
28
  print(f"Can't find transcription data file at {word_transcription_path}. Starting transcribing ...")
 
 
29
  if model_type == "faster whisper":
30
- model = stable_whisper.load_faster_whisper(model_size, device="cuda")
 
31
  else:
32
- model = stable_whisper.load_model(model_size, device="cuda")
 
 
33
  try:
34
  result = model.transcribe(temp_path, language=source_lang, vad=True, regroup=False, denoiser="demucs")
35
  except Exception as e:
36
- return None, None, None, None, f"Transcription failed: {e}"
37
  result.save_as_json(word_transcription_path)
38
 
 
39
  if max_chars or max_words:
40
  result.split_by_length(
41
  max_chars=int(max_chars) if max_chars else None,
42
  max_words=int(max_words) if max_words else None
43
  )
44
 
45
- # ----- Perform segment time extensions and anti-flickering (=closing the gaps) -----
46
  extend_start = float(extend_in) if extend_in else 0.0
47
  extend_end = float(extend_out) if extend_out else 0.0
48
  collapse_gaps_under = float(collapse_gaps) if collapse_gaps else 0.0
@@ -52,12 +62,10 @@ def process_media(
52
  next = result[i+1]
53
 
54
  if next.start - cur.end < extend_start + extend_end:
55
- # Not enough time to add the entire desired extensions -> add proportionally
56
  k = extend_end / (extend_start + extend_end) if (extend_start + extend_end) > 0 else 0
57
  mid = cur.end * (1 - k) + next.start * k
58
  cur.end = next.start = mid
59
  else:
60
- # Add full desired extensions
61
  cur.end += extend_end
62
  next.start -= extend_start
63
 
@@ -68,16 +76,11 @@ def process_media(
68
  result[0].start = max(0, result[0].start - extend_start)
69
  result[-1].end += extend_end
70
 
71
- #for seg in result:
72
- # seg.text = optimize_text(
73
- # seg.text,
74
- # int(max_lines_per_segment) if max_lines_per_segment else 3,
75
- # float(line_penalty) if line_penalty else 22.01,
76
- # float(longest_line_char_penalty) if longest_line_char_penalty else 1.0
77
- # )
78
 
79
- # Use custom SRT block output
80
- subtitles_path = tempfile.NamedTemporaryFile(delete=False, suffix=".srt", mode="w", encoding="utf-8").name
81
  result_to_any(
82
  result=result,
83
  filepath=subtitles_path,
@@ -91,23 +94,21 @@ def process_media(
91
  word_level=False,
92
  )
93
  srt_file_path = subtitles_path
94
-
95
  transcript_txt = result.to_txt()
96
 
97
  mime, _ = mimetypes.guess_type(temp_path)
98
  audio_out = temp_path if mime and mime.startswith("audio") else None
99
  video_out = temp_path if mime and mime.startswith("video") else None
100
 
101
- return audio_out, video_out, transcript_txt, srt_file_path, None
102
 
103
  def optimize_text(text, max_lines_per_segment, line_penalty, longest_line_char_penalty):
104
  text = text.strip()
105
  words = text.split()
106
 
107
- # Compute prefix sums
108
  psum = [0]
109
  for w in words:
110
- psum += [psum[-1] + len(w) + 1] # +1 because of spaces
111
 
112
  bestScore = 10 ** 30
113
  bestSplit = None
@@ -290,7 +291,7 @@ with gr.Blocks() as interface:
290
  source_lang = gr.Dropdown(
291
  choices=WHISPER_LANGUAGES,
292
  label="Source Language",
293
- value="en", # default to English
294
  interactive=True
295
  )
296
  model_type = gr.Dropdown(
 
8
  import tempfile
9
  import re
10
  import textwrap
11
+ import torch
12
 
13
+ # --- Main function to process the media file --- #
14
  def process_media(
15
  model_size, source_lang, upload, model_type,
16
  max_chars, max_words, extend_in, extend_out, collapse_gaps,
17
  max_lines_per_segment, line_penalty, longest_line_char_penalty, *args
18
  ):
19
+ # ----- is file empty? checker ----- #
20
  if upload is None:
21
+ return None, None, None, None
22
 
23
  temp_path = upload.name
24
  base_path = os.path.splitext(temp_path)[0]
25
  word_transcription_path = base_path + '.json'
26
 
27
+ # ---- Load .json or transcribe ---- #
28
  if os.path.exists(word_transcription_path):
29
  print(f"Transcription data file found at {word_transcription_path}")
30
  result = stable_whisper.WhisperResult(word_transcription_path)
31
  else:
32
  print(f"Can't find transcription data file at {word_transcription_path}. Starting transcribing ...")
33
+
34
+ #-- Check if CUDA is available or not --#
35
  if model_type == "faster whisper":
36
+ device = "cuda" if torch.cuda.is_available() else "cpu"
37
+ model = stable_whisper.load_faster_whisper(model_size, device=device)
38
  else:
39
+ device = "cuda" if torch.cuda.is_available() else "cpu"
40
+ model = stable_whisper.load_model(model_size, device=device)
41
+
42
  try:
43
  result = model.transcribe(temp_path, language=source_lang, vad=True, regroup=False, denoiser="demucs")
44
  except Exception as e:
45
+ return None, None, None, None # Remove the 5th value
46
  result.save_as_json(word_transcription_path)
47
 
48
+ # ADVANCED SETTINGS #
49
  if max_chars or max_words:
50
  result.split_by_length(
51
  max_chars=int(max_chars) if max_chars else None,
52
  max_words=int(max_words) if max_words else None
53
  )
54
 
55
+ # ----- Anti-flickering ----- #
56
  extend_start = float(extend_in) if extend_in else 0.0
57
  extend_end = float(extend_out) if extend_out else 0.0
58
  collapse_gaps_under = float(collapse_gaps) if collapse_gaps else 0.0
 
62
  next = result[i+1]
63
 
64
  if next.start - cur.end < extend_start + extend_end:
 
65
  k = extend_end / (extend_start + extend_end) if (extend_start + extend_end) > 0 else 0
66
  mid = cur.end * (1 - k) + next.start * k
67
  cur.end = next.start = mid
68
  else:
 
69
  cur.end += extend_end
70
  next.start -= extend_start
71
 
 
76
  result[0].start = max(0, result[0].start - extend_start)
77
  result[-1].end += extend_end
78
 
79
+ # --- Custom SRT block output --- #
80
+ original_filename = os.path.splitext(os.path.basename(temp_path))[0]
81
+ srt_dir = tempfile.gettempdir()
82
+ subtitles_path = os.path.join(srt_dir, f"{original_filename}.srt")
 
 
 
83
 
 
 
84
  result_to_any(
85
  result=result,
86
  filepath=subtitles_path,
 
94
  word_level=False,
95
  )
96
  srt_file_path = subtitles_path
 
97
  transcript_txt = result.to_txt()
98
 
99
  mime, _ = mimetypes.guess_type(temp_path)
100
  audio_out = temp_path if mime and mime.startswith("audio") else None
101
  video_out = temp_path if mime and mime.startswith("video") else None
102
 
103
+ return audio_out, video_out, transcript_txt, srt_file_path # Only 4 values
104
 
105
  def optimize_text(text, max_lines_per_segment, line_penalty, longest_line_char_penalty):
106
  text = text.strip()
107
  words = text.split()
108
 
 
109
  psum = [0]
110
  for w in words:
111
+ psum += [psum[-1] + len(w) + 1]
112
 
113
  bestScore = 10 ** 30
114
  bestSplit = None
 
291
  source_lang = gr.Dropdown(
292
  choices=WHISPER_LANGUAGES,
293
  label="Source Language",
294
+ value="tl", # default to Tagalog
295
  interactive=True
296
  )
297
  model_type = gr.Dropdown(