import gradio as gr
import pytesseract
import cv2
import multiprocessing
from fuzzywuzzy import fuzz
from dataclasses import dataclass
from urllib.request import urlopen
import shutil
import pathlib
import datetime
import sys

# Constants
TESSDATA_DIR = pathlib.Path.home() / 'tessdata'
TESSDATA_URL = 'https://github.com/tesseract-ocr/tessdata_fast/raw/master/{}.traineddata'
TESSDATA_SCRIPT_URL = 'https://github.com/tesseract-ocr/tessdata_best/raw/master/script/{}.traineddata'


# Download language data files if necessary
def download_lang_data(lang: str):
    TESSDATA_DIR.mkdir(parents=True, exist_ok=True)
    for lang_name in lang.split('+'):
        filepath = TESSDATA_DIR / f'{lang_name}.traineddata'
        if not filepath.is_file():
            # Script names (e.g. 'Hangul') start with an uppercase letter and live in
            # the tessdata_best/script directory; plain language codes use tessdata_fast
            url = (TESSDATA_SCRIPT_URL.format(lang_name)
                   if lang_name[0].isupper()
                   else TESSDATA_URL.format(lang_name))
            with urlopen(url) as res, open(filepath, 'w+b') as f:
                shutil.copyfileobj(res, f)


# Helper functions for time and frame conversion
def get_frame_index(time_str: str, fps: float):
    t = list(map(float, time_str.split(':')))
    if len(t) == 3:
        td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
    elif len(t) == 2:
        td = datetime.timedelta(minutes=t[0], seconds=t[1])
    else:
        raise ValueError(f'Time data "{time_str}" does not match format "%H:%M:%S"')
    return int(td.total_seconds() * fps)


def get_srt_timestamp(frame_index: int, fps: float):
    td = datetime.timedelta(seconds=frame_index / fps)
    ms = td.microseconds // 1000
    m, s = divmod(td.seconds, 60)
    h, m = divmod(m, 60)
    return f'{h:02d}:{m:02d}:{s:02d},{ms:03d}'


# Video capture context manager using OpenCV
class Capture:
    def __init__(self, video_path):
        self.path = video_path

    def __enter__(self):
        self.cap = cv2.VideoCapture(self.path)
        if not self.cap.isOpened():
            raise IOError(f'Cannot open video {self.path}.')
        return self.cap

    def __exit__(self, exc_type, exc_value, traceback):
        self.cap.release()


@dataclass
class PredictedWord:
    confidence: int
    text: str


class PredictedFrame:
    def __init__(self, index: int, pred_data: str, conf_threshold: int):
        self.index = index
        self.words = []

        block = 0
        for l in pred_data.splitlines()[1:]:
            word_data = l.split()
            if len(word_data) < 12:
                continue
            _, _, block_num, *_, conf, text = word_data
            block_num, conf = int(block_num), int(conf)

            # Insert a line break whenever a new text block starts
            if block < block_num:
                block = block_num
                if self.words and self.words[-1].text != '\n':
                    self.words.append(PredictedWord(0, '\n'))

            if conf >= conf_threshold:
                self.words.append(PredictedWord(conf, text))

        self.confidence = sum(word.confidence for word in self.words)
        # Join words, map '|' to 'I', and strip characters unlikely to appear in subtitles
        self.text = ' '.join(word.text for word in self.words).translate(
            str.maketrans('|', 'I', '<>{}[];`@#$%^*_=~\\')).replace(' \n ', '\n').strip()

    def is_similar_to(self, other, threshold=70):
        return fuzz.ratio(self.text, other.text) >= threshold


class PredictedSubtitle:
    def __init__(self, frames, sim_threshold):
        self.frames = [f for f in frames if f.confidence > 0]
        self.sim_threshold = sim_threshold
        # Use the text of the most confident frame as the subtitle text
        self.text = max(self.frames, key=lambda f: f.confidence).text if self.frames else ''

    @property
    def index_start(self):
        return self.frames[0].index if self.frames else 0

    @property
    def index_end(self):
        return self.frames[-1].index if self.frames else 0

    def is_similar_to(self, other):
        return fuzz.partial_ratio(self.text, other.text) >= self.sim_threshold


class Video:
    def __init__(self, path):
        self.path = path
        self.pred_frames = None
        with Capture(path) as v:
            self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
            self.fps = v.get(cv2.CAP_PROP_FPS)
            self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))

    def run_ocr(self, lang, time_start, time_end, conf_threshold, use_fullframe):
        self.lang = lang
        self.use_fullframe = use_fullframe
        ocr_start = get_frame_index(time_start, self.fps) if time_start else 0
        ocr_end = get_frame_index(time_end, self.fps) if time_end else self.num_frames

        if ocr_end < ocr_start:
            raise ValueError('time_start is later than time_end')
        num_ocr_frames = ocr_end - ocr_start

        # Run OCR on the selected frames in parallel worker processes
        with Capture(self.path) as v, multiprocessing.Pool() as pool:
            v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
            frames = (v.read()[1] for _ in range(num_ocr_frames))
            it_ocr = pool.imap(self._image_to_data, frames, chunksize=10)
            self.pred_frames = [
                PredictedFrame(i + ocr_start, data, conf_threshold)
                for i, data in enumerate(it_ocr)]

    def _image_to_data(self, img):
        if not self.use_fullframe:
            # Subtitles usually sit in the bottom half of the frame
            img = img[self.height // 2:, :]
        config = f'--tessdata-dir "{TESSDATA_DIR}"'
        try:
            return pytesseract.image_to_data(img, lang=self.lang, config=config)
        except Exception as e:
            sys.exit(f'{e.__class__.__name__}: {e}')

    def get_subtitles(self, sim_threshold):
        self._generate_subtitles(sim_threshold)
        return ''.join(
            f'{i}\n'
            f'{get_srt_timestamp(sub.index_start, self.fps)} --> '
            f'{get_srt_timestamp(sub.index_end, self.fps)}\n'
            f'{sub.text}\n\n'
            for i, sub in enumerate(self.pred_subs))

    def _generate_subtitles(self, sim_threshold):
        self.pred_subs = []

        if self.pred_frames is None:
            raise AttributeError('Please call self.run_ocr() first to perform OCR on frames')

        # Group consecutive similar frames into subtitle paragraphs using a sliding window
        WIN_BOUND = int(self.fps // 2)  # ~1/2 second window boundary
        bound = WIN_BOUND
        i = 0
        j = 1
        while j < len(self.pred_frames):
            fi, fj = self.pred_frames[i], self.pred_frames[j]

            if fi.is_similar_to(fj):
                bound = WIN_BOUND
            elif bound > 0:
                bound -= 1
            else:
                # Start a new subtitle paragraph
                para_new = j - WIN_BOUND
                self._append_sub(PredictedSubtitle(
                    self.pred_frames[i:para_new], sim_threshold))
                i = para_new
                j = i
                bound = WIN_BOUND

            j += 1

        # Handle the remaining frames at the end of the video
        if i < len(self.pred_frames) - 1:
            self._append_sub(PredictedSubtitle(self.pred_frames[i:], sim_threshold))

    def _append_sub(self, sub):
        if not sub.text:
            return

        # Merge the new subtitle with previous ones if they are similar
        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
            ls = self.pred_subs.pop()
            sub = PredictedSubtitle(ls.frames + sub.frames, sub.sim_threshold)

        self.pred_subs.append(sub)


# Gradio app
def extract_subtitles(video_file, lang, time_start, time_end, conf_threshold,
                      use_fullframe, sim_threshold):
    # Recent Gradio versions pass the uploaded video as a file path string;
    # fall back to .name for file-like objects from older versions
    video_path = video_file if isinstance(video_file, str) else video_file.name
    download_lang_data(lang)  # fetch the required .traineddata files before OCR
    video = Video(video_path)
    video.run_ocr(lang, time_start, time_end, conf_threshold, use_fullframe)
    subtitles = video.get_subtitles(sim_threshold)
    return subtitles


iface = gr.Interface(
    fn=extract_subtitles,
    inputs=[
        gr.Video(label="Video File"),
        gr.Textbox(value='eng', label="OCR Language"),
        gr.Textbox(value='00:00:00', label="Start Time (HH:MM:SS)"),
        gr.Textbox(value='', label="End Time (HH:MM:SS, leave empty for full video)"),
        gr.Slider(0, 100, value=60, step=1, label="Confidence Threshold"),
        gr.Checkbox(value=False, label="Use Full Frame for OCR"),
        gr.Slider(0, 100, value=70, step=1, label="Similarity Threshold"),
    ],
    outputs=gr.Textbox(label="Extracted Subtitles"),
    title="Video Subtitle Extractor",
    description="Extract hardcoded subtitles from videos using Tesseract OCR.",
)

# Guard the entry point so multiprocessing workers can re-import this module
# without launching the app again
if __name__ == '__main__':
    iface.launch()
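
# --- Hedged sketch: using the classes above programmatically, without the Gradio UI. ---
# The module name `subtitle_ocr` and the file paths below are assumptions for
# illustration only; adjust them to however this script is saved.
#
#   from subtitle_ocr import Video, download_lang_data
#
#   download_lang_data('eng')                      # fetch eng.traineddata into ~/tessdata
#   video = Video('example.mp4')
#   video.run_ocr('eng', '00:00:00', '00:01:30',
#                 conf_threshold=60, use_fullframe=False)
#   with open('example.srt', 'w', encoding='utf-8') as f:
#       f.write(video.get_subtitles(sim_threshold=70))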