import datetime
import multiprocessing
import pathlib
import shutil
import sys
from dataclasses import dataclass
from urllib.request import urlopen

import cv2
import gradio as gr
import pytesseract
from fuzzywuzzy import fuzz

# Constants
TESSDATA_DIR = pathlib.Path.home() / 'tessdata'
TESSDATA_URL = 'https://github.com/tesseract-ocr/tessdata_fast/raw/master/{}.traineddata'
TESSDATA_SCRIPT_URL = 'https://github.com/tesseract-ocr/tessdata_best/raw/master/script/{}.traineddata'

# Download language data files if necessary
def download_lang_data(lang: str):
    TESSDATA_DIR.mkdir(parents=True, exist_ok=True)
    for lang_name in lang.split('+'):
        filepath = TESSDATA_DIR / f'{lang_name}.traineddata'
        if not filepath.is_file():
            # Capitalised names (e.g. 'HanS') are script models hosted under
            # tessdata_best/script; lowercase names are per-language models.
            url = TESSDATA_SCRIPT_URL.format(lang_name) if lang_name[0].isupper() else TESSDATA_URL.format(lang_name)
            with urlopen(url) as res, open(filepath, 'w+b') as f:
                shutil.copyfileobj(res, f)
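
# Quick sanity check (assumes network access; file names follow the real
# tessdata naming conventions on GitHub):
#   download_lang_data('eng')       -> fetches eng.traineddata from tessdata_fast
#   download_lang_data('eng+HanS')  -> also fetches script/HanS.traineddata from tessdata_best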

# Helper functions for time and frame conversion
def get_frame_index(time_str: str, fps: float):
    t = list(map(float, time_str.split(':')))
    if len(t) == 3:
        td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
    elif len(t) == 2:
        td = datetime.timedelta(minutes=t[0], seconds=t[1])
    else:
        raise ValueError(f'Time data "{time_str}" does not match format "HH:MM:SS" or "MM:SS"')
    return int(td.total_seconds() * fps)
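
# Worked examples: at fps=24.0, both forms map 90 seconds to frame 2160.
#   get_frame_index('00:01:30', 24.0) == 2160   # HH:MM:SS form
#   get_frame_index('01:30', 24.0)    == 2160   # MM:SS form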

def get_srt_timestamp(frame_index: int, fps: float):
    td = datetime.timedelta(seconds=frame_index / fps)
    ms = td.microseconds // 1000
    m, s = divmod(td.seconds, 60)
    h, m = divmod(m, 60)
    return f'{h:02d}:{m:02d}:{s:02d},{ms:03d}'
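
# Worked example: frame 90 at 30 fps is 3 seconds into the video.
#   get_srt_timestamp(90, 30.0) == '00:00:03,000'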

# Video capture class using OpenCV
class Capture:
    def __init__(self, video_path):
        self.path = video_path

    def __enter__(self):
        self.cap = cv2.VideoCapture(self.path)
        if not self.cap.isOpened():
            raise IOError(f'Cannot open video {self.path}.')
        return self.cap

    def __exit__(self, exc_type, exc_value, traceback):
        self.cap.release()
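
# Typical usage sketch ('clip.mp4' is a hypothetical local file):
#   with Capture('clip.mp4') as cap:
#       ok, frame = cap.read()
# The capture is released automatically on exit, even if reading raises.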

@dataclass
class PredictedWord:
    confidence: int
    text: str

class PredictedFrame:
    def __init__(self, index: int, pred_data: str, conf_threshold: int):
        self.index = index
        self.words = []
        block = 0
        # pred_data is pytesseract TSV output; the header row is skipped.
        # Columns: level, page_num, block_num, par_num, line_num, word_num,
        # left, top, width, height, conf, text
        for line in pred_data.splitlines()[1:]:
            word_data = line.split()
            if len(word_data) < 12:
                # No text detected on this row
                continue
            _, _, block_num, *_, conf, text = word_data
            # conf may be a float string in Tesseract 4+, e.g. '96.05'
            block_num, conf = int(block_num), int(float(conf))
            # Insert a newline between text blocks
            if block < block_num:
                block = block_num
                if self.words and self.words[-1].text != '\n':
                    self.words.append(PredictedWord(0, '\n'))
            if conf >= conf_threshold:
                self.words.append(PredictedWord(conf, text))
        self.confidence = sum(word.confidence for word in self.words)
        text = ' '.join(word.text for word in self.words)
        # Map '|' to 'I' and delete characters that are almost always OCR noise
        self.text = text.translate(
            str.maketrans('|', 'I', '<>{}[];`@#$%^*_=~\\')
        ).replace(' \n ', '\n').strip()

    def is_similar_to(self, other, threshold=70):
        return fuzz.ratio(self.text, other.text) >= threshold
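
# For reference, fuzz.ratio is a Levenshtein-style similarity in [0, 100], e.g.
#   fuzz.ratio('hello world', 'hello word') == 95
# so two frames showing the same subtitle with a one-character OCR glitch
# still clear the default threshold of 70.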

class PredictedSubtitle:
    def __init__(self, frames, sim_threshold):
        self.frames = [f for f in frames if f.confidence > 0]
        self.sim_threshold = sim_threshold
        # Use the text of the most confident frame as the subtitle text
        self.text = max(self.frames, key=lambda f: f.confidence).text if self.frames else ''

    @property
    def index_start(self):
        return self.frames[0].index if self.frames else 0

    @property
    def index_end(self):
        return self.frames[-1].index if self.frames else 0

    def is_similar_to(self, other):
        return fuzz.partial_ratio(self.text, other.text) >= self.sim_threshold
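
# fuzz.partial_ratio scores the best-matching substring, so a short subtitle
# embedded in a longer OCR result still matches, e.g.
#   fuzz.partial_ratio('hello', 'say hello world') == 100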

class Video:
    def __init__(self, path):
        self.path = path
        self.pred_frames = None  # populated by run_ocr()
        with Capture(path) as v:
            self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
            self.fps = v.get(cv2.CAP_PROP_FPS)
            self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))

    def run_ocr(self, lang, time_start, time_end, conf_threshold, use_fullframe):
        self.lang = lang
        self.use_fullframe = use_fullframe
        ocr_start = get_frame_index(time_start, self.fps) if time_start else 0
        ocr_end = get_frame_index(time_end, self.fps) if time_end else self.num_frames
        if ocr_end < ocr_start:
            raise ValueError('time_start is later than time_end')
        num_ocr_frames = ocr_end - ocr_start
        # The generator is consumed lazily by imap, so frames are read from the
        # capture a chunk at a time while worker processes run OCR in parallel
        with Capture(self.path) as v, multiprocessing.Pool() as pool:
            v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
            frames = (v.read()[1] for _ in range(num_ocr_frames))
            it_ocr = pool.imap(self._image_to_data, frames, chunksize=10)
            self.pred_frames = [
                PredictedFrame(i + ocr_start, data, conf_threshold)
                for i, data in enumerate(it_ocr)]

    def _image_to_data(self, img):
        # Hardcoded subtitles usually sit in the bottom half of the frame
        if not self.use_fullframe:
            img = img[self.height // 2:, :]
        config = f'--tessdata-dir "{TESSDATA_DIR}"'
        try:
            return pytesseract.image_to_data(img, lang=self.lang, config=config)
        except Exception as e:
            sys.exit(f'{e.__class__.__name__}: {e}')

    def get_subtitles(self, sim_threshold):
        self._generate_subtitles(sim_threshold)
        # SRT cue numbers conventionally start at 1
        return ''.join(
            f'{i}\n'
            f'{get_srt_timestamp(sub.index_start, self.fps)} --> '
            f'{get_srt_timestamp(sub.index_end, self.fps)}\n'
            f'{sub.text}\n\n'
            for i, sub in enumerate(self.pred_subs, start=1))

    def _generate_subtitles(self, sim_threshold):
        self.pred_subs = []
        if self.pred_frames is None:
            raise AttributeError('Please call self.run_ocr() first to perform OCR on frames')
        # Tolerate up to about half a second of dissimilar frames (OCR flicker)
        # before starting a new subtitle paragraph
        WIN_BOUND = int(self.fps // 2)
        bound = WIN_BOUND
        i = 0
        j = 1
        while j < len(self.pred_frames):
            fi, fj = self.pred_frames[i], self.pred_frames[j]
            if fi.is_similar_to(fj):
                bound = WIN_BOUND
            elif bound > 0:
                bound -= 1
            else:
                # Subtitle changed: the new paragraph started WIN_BOUND frames ago
                para_new = j - WIN_BOUND
                self._append_sub(PredictedSubtitle(self.pred_frames[i:para_new], sim_threshold))
                i = para_new
                j = i
                bound = WIN_BOUND
            j += 1
        # Flush the last paragraph
        if i < len(self.pred_frames) - 1:
            self._append_sub(PredictedSubtitle(self.pred_frames[i:], sim_threshold))

    def _append_sub(self, sub):
        if not sub.text:
            return
        # Merge consecutive subtitles that are similar enough to be the same line
        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
            ls = self.pred_subs.pop()
            sub = PredictedSubtitle(ls.frames + sub.frames, sub.sim_threshold)
        self.pred_subs.append(sub)
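
# Rough standalone usage outside the Gradio app ('sample.mp4' is a hypothetical
# local file; language data must already be present in TESSDATA_DIR):
#   video = Video('sample.mp4')
#   video.run_ocr('eng', '00:00:00', '00:00:10', conf_threshold=60, use_fullframe=False)
#   print(video.get_subtitles(sim_threshold=70))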

# Gradio app
def extract_subtitles(video_file, lang, time_start, time_end, conf_threshold, use_fullframe, sim_threshold):
    # gr.Video hands the function a filepath string, not a file object
    download_lang_data(lang)  # fetch any missing traineddata files first
    video = Video(video_file)
    video.run_ocr(lang, time_start, time_end, conf_threshold, use_fullframe)
    subtitles = video.get_subtitles(sim_threshold)
    return subtitles

iface = gr.Interface(
    fn=extract_subtitles,
    inputs=[
        gr.Video(label="Video File"),
        gr.Textbox(value='eng', label="OCR Language"),
        gr.Textbox(value='00:00:00', label="Start Time (HH:MM:SS)"),
        gr.Textbox(value='', label="End Time (HH:MM:SS, leave empty for full video)"),
        gr.Slider(0, 100, value=60, step=1, label="Confidence Threshold"),
        gr.Checkbox(value=False, label="Use Full Frame for OCR"),
        gr.Slider(0, 100, value=70, step=1, label="Similarity Threshold")
    ],
    outputs=gr.Textbox(label="Extracted Subtitles"),
    title="Video Subtitle Extractor",
    description="Extract hardcoded subtitles from videos using Tesseract OCR.",
)

iface.launch()
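
# To expose a temporary public URL when running locally rather than on Spaces,
# Gradio's standard share flag can be used instead: iface.launch(share=True)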