import gradio as gr import mimetypes import os os.environ['KMP_DUPLICATE_LIB_OK']='True' import argparse import stable_whisper from stable_whisper.text_output import result_to_any, sec2srt import tempfile import re import textwrap import torch def process_media( model_size, source_lang, upload, model_type, max_chars, max_words, extend_in, extend_out, collapse_gaps, max_lines_per_segment, line_penalty, longest_line_char_penalty, *args ): # ----- is file empty? checker ----- # if upload is None: return None, None, None, None temp_path = upload.name base_path = os.path.splitext(temp_path)[0] word_transcription_path = base_path + '.json' # ---- Load .json or transcribe ---- # if os.path.exists(word_transcription_path): print(f"Transcription data file found at {word_transcription_path}") result = stable_whisper.WhisperResult(word_transcription_path) else: print(f"Can't find transcription data file at {word_transcription_path}. Starting transcribing ...") #-- Check if CUDA is available or not --# if model_type == "faster whisper": device = "cuda" if torch.cuda.is_available() else "cpu" model = stable_whisper.load_faster_whisper(model_size, device=device) else: device = "cuda" if torch.cuda.is_available() else "cpu" model = stable_whisper.load_model(model_size, device=device) try: result = model.transcribe(temp_path, language=source_lang, vad=True, regroup=False, denoiser="demucs") except Exception as e: return None, None, None, None # Remove the 5th value result.save_as_json(word_transcription_path) # ADVANCED SETTINGS # if max_chars or max_words: result.split_by_length( max_chars=int(max_chars) if max_chars else None, max_words=int(max_words) if max_words else None ) # ----- Anti-flickering ----- # extend_start = float(extend_in) if extend_in else 0.0 extend_end = float(extend_out) if extend_out else 0.0 collapse_gaps_under = float(collapse_gaps) if collapse_gaps else 0.0 for i in range(len(result) - 1): cur = result[i] next = result[i+1] if next.start - cur.end < extend_start + extend_end: k = extend_end / (extend_start + extend_end) if (extend_start + extend_end) > 0 else 0 mid = cur.end * (1 - k) + next.start * k cur.end = next.start = mid else: cur.end += extend_end next.start -= extend_start if next.start - cur.end <= collapse_gaps_under: cur.end = next.start = (cur.end + next.start) / 2 if result: result[0].start = max(0, result[0].start - extend_start) result[-1].end += extend_end # --- Custom SRT block output --- # original_filename = os.path.splitext(os.path.basename(temp_path))[0] srt_dir = tempfile.gettempdir() subtitles_path = os.path.join(srt_dir, f"{original_filename}.srt") result_to_any( result=result, filepath=subtitles_path, filetype='srt', segments2blocks=lambda segments: segments2blocks( segments, int(max_lines_per_segment) if max_lines_per_segment else 3, float(line_penalty) if line_penalty else 22.01, float(longest_line_char_penalty) if longest_line_char_penalty else 1.0 ), word_level=False, ) srt_file_path = subtitles_path transcript_txt = result.to_txt() mime, _ = mimetypes.guess_type(temp_path) audio_out = temp_path if mime and mime.startswith("audio") else None video_out = temp_path if mime and mime.startswith("video") else None return audio_out, video_out, transcript_txt, srt_file_path # Only 4 values def optimize_text(text, max_lines_per_segment, line_penalty, longest_line_char_penalty): text = text.strip() words = text.split() psum = [0] for w in words: psum += [psum[-1] + len(w) + 1] bestScore = 10 ** 30 bestSplit = None def backtrack(level, wordsUsed, maxLineLength, split): nonlocal bestScore, bestSplit if wordsUsed == len(words): score = level * line_penalty + maxLineLength * longest_line_char_penalty if score < bestScore: bestScore = score bestSplit = split return if level + 1 == max_lines_per_segment: backtrack( level + 1, len(words), max(maxLineLength, psum[len(words)] - psum[wordsUsed] - 1), split + [words[wordsUsed:]] ) return for levelWords in range(1, len(words) - wordsUsed + 1): backtrack( level + 1, wordsUsed + levelWords, max(maxLineLength, psum[wordsUsed + levelWords] - psum[wordsUsed] - 1), split + [words[wordsUsed:wordsUsed + levelWords]] ) backtrack(0, 0, 0, []) optimized = '\n'.join(' '.join(words) for words in bestSplit) return optimized def segment2optimizedsrtblock(segment: dict, idx: int, max_lines_per_segment, line_penalty, longest_line_char_penalty, strip=True) -> str: return f'{idx}\n{sec2srt(segment["start"])} --> {sec2srt(segment["end"])}\n' \ f'{optimize_text(segment["text"], max_lines_per_segment, line_penalty, longest_line_char_penalty)}' def segments2blocks(segments, max_lines_per_segment, line_penalty, longest_line_char_penalty): return '\n\n'.join( segment2optimizedsrtblock(s, i, max_lines_per_segment, line_penalty, longest_line_char_penalty, strip=True) for i, s in enumerate(segments) ) WHISPER_LANGUAGES = [ ("Afrikaans", "af"), ("Albanian", "sq"), ("Amharic", "am"), ("Arabic", "ar"), ("Armenian", "hy"), ("Assamese", "as"), ("Azerbaijani", "az"), ("Bashkir", "ba"), ("Basque", "eu"), ("Belarusian", "be"), ("Bengali", "bn"), ("Bosnian", "bs"), ("Breton", "br"), ("Bulgarian", "bg"), ("Burmese", "my"), ("Catalan", "ca"), ("Chinese", "zh"), ("Croatian", "hr"), ("Czech", "cs"), ("Danish", "da"), ("Dutch", "nl"), ("English", "en"), ("Estonian", "et"), ("Faroese", "fo"), ("Finnish", "fi"), ("French", "fr"), ("Galician", "gl"), ("Georgian", "ka"), ("German", "de"), ("Greek", "el"), ("Gujarati", "gu"), ("Haitian Creole", "ht"), ("Hausa", "ha"), ("Hebrew", "he"), ("Hindi", "hi"), ("Hungarian", "hu"), ("Icelandic", "is"), ("Indonesian", "id"), ("Italian", "it"), ("Japanese", "ja"), ("Javanese", "jv"), ("Kannada", "kn"), ("Kazakh", "kk"), ("Khmer", "km"), ("Korean", "ko"), ("Lao", "lo"), ("Latin", "la"), ("Latvian", "lv"), ("Lingala", "ln"), ("Lithuanian", "lt"), ("Luxembourgish", "lb"), ("Macedonian", "mk"), ("Malagasy", "mg"), ("Malay", "ms"), ("Malayalam", "ml"), ("Maltese", "mt"), ("Maori", "mi"), ("Marathi", "mr"), ("Mongolian", "mn"), ("Nepali", "ne"), ("Norwegian", "no"), ("Nyanja", "ny"), ("Occitan", "oc"), ("Pashto", "ps"), ("Persian", "fa"), ("Polish", "pl"), ("Portuguese", "pt"), ("Punjabi", "pa"), ("Romanian", "ro"), ("Russian", "ru"), ("Sanskrit", "sa"), ("Serbian", "sr"), ("Shona", "sn"), ("Sindhi", "sd"), ("Sinhala", "si"), ("Slovak", "sk"), ("Slovenian", "sl"), ("Somali", "so"), ("Spanish", "es"), ("Sundanese", "su"), ("Swahili", "sw"), ("Swedish", "sv"), ("Tagalog", "tl"), ("Tajik", "tg"), ("Tamil", "ta"), ("Tatar", "tt"), ("Telugu", "te"), ("Thai", "th"), ("Turkish", "tr"), ("Turkmen", "tk"), ("Ukrainian", "uk"), ("Urdu", "ur"), ("Uzbek", "uz"), ("Vietnamese", "vi"), ("Welsh", "cy"), ("Yiddish", "yi"), ("Yoruba", "yo"), ] with gr.Blocks() as interface: gr.HTML( """