# videoocr2 / app.py
# Page metadata: leetuan023's picture | Create app.py | 7d8ee16 verified |
# raw | history blame | 7.78 kB
import gradio as gr
import pytesseract
import cv2
import multiprocessing
from fuzzywuzzy import fuzz
from dataclasses import dataclass
from urllib.request import urlopen
import shutil
import pathlib
import datetime
import sys
# Constants
# Local directory holding Tesseract ``*.traineddata`` files; OCR passes it to
# Tesseract through ``--tessdata-dir``.
TESSDATA_DIR = pathlib.Path(pathlib.Path.home(), 'tessdata')
# Download URL templates; ``{}`` is replaced by the language or script name.
# Lower-case language codes (e.g. 'eng') come from the "fast" model repository,
# names starting with an upper-case letter (e.g. 'Latin') are script models
# from the "best" repository.
TESSDATA_URL = 'https://github.com/tesseract-ocr/tessdata_fast/raw/master/{}.traineddata'
TESSDATA_SCRIPT_URL = 'https://github.com/tesseract-ocr/tessdata_best/raw/master/script/{}.traineddata'
# Download language data files if necessary
def download_lang_data(lang: str):
    """Make sure Tesseract traineddata exists locally for every language in ``lang``.

    ``lang`` uses Tesseract's ``+`` syntax (e.g. ``'eng+chi_sim'``). Each
    missing ``<name>.traineddata`` file is downloaded into ``TESSDATA_DIR``.
    Names starting with an upper-case letter are fetched from
    ``TESSDATA_SCRIPT_URL``; all others are fetched from ``TESSDATA_URL``.
    Empty names (e.g. from a trailing ``+``) are ignored.

    Raises:
        ValueError: if a name contains anything other than ASCII letters,
            digits and ``_``. The name becomes part of a file path and a URL
            and may come straight from user input, so it is validated first.
        OSError: (including ``urllib.error.URLError``) if the download or the
            file write fails.
    """
    TESSDATA_DIR.mkdir(parents=True, exist_ok=True)
    for lang_name in lang.split('+'):
        if not lang_name:
            # Tolerate stray separators such as 'eng+' (previously IndexError).
            continue
        # Reject '../x', '/', spaces, etc. to prevent path traversal / URL tampering.
        if not (lang_name.isascii() and lang_name.replace('_', '').isalnum()):
            raise ValueError(f'Invalid language name: {lang_name!r}')
        filepath = TESSDATA_DIR / f'{lang_name}.traineddata'
        if filepath.is_file():
            continue
        template = TESSDATA_SCRIPT_URL if lang_name[0].isupper() else TESSDATA_URL
        # Write to a temporary name and rename only after a complete transfer.
        # Otherwise an interrupted download would leave a truncated file, and
        # the is_file() check above would accept it on every later call.
        part_path = filepath.with_name(filepath.name + '.part')
        try:
            with urlopen(template.format(lang_name)) as res, open(part_path, 'wb') as f:
                shutil.copyfileobj(res, f)
            part_path.replace(filepath)
        except BaseException:
            if part_path.exists():
                part_path.unlink()
            raise
# Helper functions for time and frame conversion
def get_frame_index(time_str: str, fps: float):
    """Convert a timestamp string into a frame index at ``fps`` frames per second.

    Accepted forms are ``HH:MM:SS``, ``MM:SS`` and plain ``SS``. Every
    component may be fractional (e.g. ``'00:01:02.5'``). The product of
    seconds and fps is truncated towards zero with ``int()``.

    Raises:
        ValueError: if ``time_str`` has more than three components, or if a
            component is not a finite, non-negative number.
    """
    error = ValueError(f'Time data "{time_str}" does not match format "[[HH:]MM:]SS"')
    try:
        parts = [float(p) for p in time_str.split(':')]
    except ValueError:
        raise error from None
    # `not 0 <= p < inf` also rejects NaN, because comparisons with NaN are False.
    if len(parts) > 3 or any(not 0 <= p < float('inf') for p in parts):
        raise error
    # Left-pad to (hours, minutes, seconds) so 'SS' and 'MM:SS' are handled uniformly.
    hours, minutes, seconds = [0.0] * (3 - len(parts)) + parts
    td = datetime.timedelta(hours=hours, minutes=minutes, seconds=seconds)
    return int(td.total_seconds() * fps)
def get_srt_timestamp(frame_index: int, fps: float):
    """Return the time of ``frame_index`` (i.e. ``frame_index / fps`` seconds) as an SRT timestamp.

    The format is ``'HH:MM:SS,mmm'``. Hours are not wrapped at 24, so
    positions beyond one day are still rendered correctly.
    """
    td = datetime.timedelta(seconds=frame_index / fps)
    ms = td.microseconds // 1000
    # td.seconds only holds the part within the current day, so add whole days back.
    m, s = divmod(td.days * 86400 + td.seconds, 60)
    h, m = divmod(m, 60)
    return f'{h:02d}:{m:02d}:{s:02d},{ms:03d}'
# Video capture class using OpenCV
class Capture:
    """Context manager that opens a video with OpenCV and releases it on exit.

    Usage::

        with Capture(path) as cap:   # cap is the cv2.VideoCapture
            ok, frame = cap.read()

    Entering raises ``OSError`` when OpenCV cannot open the file.
    """

    def __init__(self, video_path):
        self.path = video_path

    def __enter__(self):
        capture = cv2.VideoCapture(self.path)
        self.cap = capture
        if capture.isOpened():
            return capture
        raise OSError(f'Cannot open video {self.path}.')

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release; returning None lets any exception propagate.
        self.cap.release()
@dataclass
class PredictedWord:
    """One word recognised by Tesseract (or a block separator)."""

    confidence: int  # Tesseract word confidence as an int; 0 for separators
    text: str  # recognised word, or '\n' marking the start of a new text block


class PredictedFrame:
    """OCR result for one video frame, parsed from Tesseract ``image_to_data`` TSV output.

    Attributes:
        index: frame number within the video.
        words: ``PredictedWord`` items whose confidence reached
            ``conf_threshold``. A ``PredictedWord(0, '\\n')`` separator is
            inserted whenever a new Tesseract block starts.
        confidence: sum of the kept words' confidences.
        text: the words joined by spaces/newlines, with ``|`` mapped to ``I``
            and a fixed set of punctuation characters removed.
    """

    # Column positions in Tesseract's TSV output: level, page_num, block_num,
    # par_num, line_num, word_num, left, top, width, height, conf, text.
    _COL_BLOCK_NUM = 2
    _COL_CONF = 10
    _COL_TEXT = 11

    def __init__(self, index: int, pred_data: str, conf_threshold: int):
        self.index = index
        self.words = []
        block = 0
        for line in pred_data.splitlines()[1:]:  # skip the TSV header row
            # Split on tabs, not on arbitrary whitespace. A space inside the
            # text column would otherwise shift the conf column.
            fields = line.split('\t')
            if len(fields) < 12:
                continue
            text = fields[self._COL_TEXT].strip()
            if not text:
                # Page/block/paragraph/line rows carry no text; only words count.
                continue
            block_num = int(fields[self._COL_BLOCK_NUM])
            # Tesseract >= 4.1 reports conf as a float (e.g. '96.5'), and
            # int() on such a string raises, so convert via float first.
            conf = int(float(fields[self._COL_CONF]))
            if block < block_num:
                block = block_num
                if self.words and self.words[-1].text != '\n':
                    self.words.append(PredictedWord(0, '\n'))
            if conf >= conf_threshold:
                self.words.append(PredictedWord(conf, text))
        self.confidence = sum(word.confidence for word in self.words)
        self.text = ' '.join(word.text for word in self.words).translate(str.maketrans('|', 'I', '<>{}[];`@#$%^*_=~\\')).replace(' \n ', '\n').strip()

    def is_similar_to(self, other, threshold=70):
        """Return True when ``fuzz.ratio`` of the two frames' texts is >= ``threshold`` (0-100)."""
        return fuzz.ratio(self.text, other.text) >= threshold
class PredictedSubtitle:
    """A group of OCR'd frames treated as one subtitle entry.

    Frames whose confidence is not positive are dropped on construction. The
    subtitle's text is taken from the remaining frame with the highest
    confidence (the earliest one wins a tie). With no frames left, the text
    is ``''``.
    """

    def __init__(self, frames, sim_threshold):
        kept = []
        for frame in frames:
            if frame.confidence > 0:
                kept.append(frame)
        self.frames = kept
        self.sim_threshold = sim_threshold
        if kept:
            self.text = max(kept, key=lambda fr: fr.confidence).text
        else:
            self.text = ''

    @property
    def index_start(self):
        """Frame index of the first kept frame, or 0 when there is none."""
        if not self.frames:
            return 0
        return self.frames[0].index

    @property
    def index_end(self):
        """Frame index of the last kept frame, or 0 when there is none."""
        if not self.frames:
            return 0
        return self.frames[-1].index

    def is_similar_to(self, other):
        """True when ``fuzz.partial_ratio`` of the two texts reaches ``self.sim_threshold``."""
        score = fuzz.partial_ratio(self.text, other.text)
        return score >= self.sim_threshold
class Video:
    """A video file whose frames are OCR'd to extract hard-coded subtitles.

    Typical use::

        video = Video(path)
        video.run_ocr('eng', '', '', 60, False)
        srt_text = video.get_subtitles(70)
    """

    def __init__(self, path):
        self.path = path
        with Capture(path) as v:
            # These values come from OpenCV properties. The frame count may
            # not match the number of decodable frames, so run_ocr() stops
            # early when a read fails.
            self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
            self.fps = v.get(cv2.CAP_PROP_FPS)
            self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Filled in by run_ocr() / get_subtitles(). Initialised here so that
        # the "call run_ocr() first" check in _generate_subtitles can fire.
        self.lang = None
        self.use_fullframe = False
        self.pred_frames = None
        self.pred_subs = []

    def run_ocr(self, lang, time_start, time_end, conf_threshold, use_fullframe):
        """OCR every frame between ``time_start`` and ``time_end`` into ``self.pred_frames``.

        Args:
            lang: Tesseract language spec, e.g. ``'eng'`` or ``'eng+chi_sim'``.
            time_start: timestamp parsed by ``get_frame_index``; empty means
                the start of the video.
            time_end: timestamp parsed by ``get_frame_index``; empty means the
                end of the video.
            conf_threshold: minimum per-word confidence to keep.
            use_fullframe: OCR the whole frame instead of only its bottom half.

        Raises:
            ValueError: if a timestamp is invalid or ``time_start`` is later
                than ``time_end``.
            RuntimeError: if Tesseract fails on a frame.
        """
        self.lang = lang
        self.use_fullframe = use_fullframe
        ocr_start = get_frame_index(time_start, self.fps) if time_start else 0
        ocr_end = get_frame_index(time_end, self.fps) if time_end else self.num_frames
        if ocr_end < ocr_start:
            raise ValueError('time_start is later than time_end')
        num_ocr_frames = ocr_end - ocr_start
        with Capture(self.path) as v, multiprocessing.Pool() as pool:
            v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
            frames = self._read_frames(v, num_ocr_frames)
            it_ocr = pool.imap(self._image_to_data, frames, chunksize=10)
            self.pred_frames = [
                PredictedFrame(i + ocr_start, data, conf_threshold)
                for i, data in enumerate(it_ocr)
            ]

    @staticmethod
    def _read_frames(cap, count):
        """Yield up to ``count`` frames from ``cap``, stopping at the first failed read."""
        for _ in range(count):
            ok, frame = cap.read()
            if not ok or frame is None:
                # The requested range can run past the real end of the video.
                # A None frame would crash the OCR worker, so end the stream here.
                return
            yield frame

    def _image_to_data(self, img):
        """Run Tesseract on one frame (inside a pool worker) and return its TSV output.

        Unless ``use_fullframe`` is set, only the bottom half of the frame is
        OCR'd; the code assumes subtitles are placed there.
        """
        if not self.use_fullframe:
            # Use the frame's actual height rather than the container metadata.
            img = img[img.shape[0] // 2:, :]
        config = f'--tessdata-dir "{TESSDATA_DIR}"'
        try:
            return pytesseract.image_to_data(img, lang=self.lang, config=config)
        except Exception as e:
            # Raise an ordinary, picklable exception so multiprocessing.Pool
            # forwards it to the caller through imap(). sys.exit() raised
            # SystemExit instead, which Pool workers do not catch. The worker
            # died, its result never arrived and imap() blocked forever.
            raise RuntimeError(f'{e.__class__.__name__}: {e}') from e

    def get_subtitles(self, sim_threshold):
        """Group the OCR'd frames into subtitles and return them as SRT text.

        Entries are numbered from 1, as the SRT format expects.
        """
        self._generate_subtitles(sim_threshold)
        return ''.join(
            f'{i}\n{get_srt_timestamp(sub.index_start, self.fps)} --> '
            f'{get_srt_timestamp(sub.index_end, self.fps)}\n{sub.text}\n\n'
            for i, sub in enumerate(self.pred_subs, start=1)
        )

    def _generate_subtitles(self, sim_threshold):
        """Group ``self.pred_frames`` into ``PredictedSubtitle`` objects in ``self.pred_subs``.

        Each frame is compared with the first frame of the current group
        (``PredictedFrame.is_similar_to`` with its default threshold). The
        group is closed once more than ``WIN_BOUND`` (about half a second of)
        dissimilar frames follow the last similar one. The cut is placed at
        the first frame of that dissimilar run. A trailing group of only one
        frame is discarded.
        """
        self.pred_subs = []
        if self.pred_frames is None:
            raise AttributeError('Please call self.run_ocr() first to perform OCR on frames')
        WIN_BOUND = int(self.fps // 2)
        bound = WIN_BOUND
        i = 0
        j = 1
        while j < len(self.pred_frames):
            fi, fj = self.pred_frames[i], self.pred_frames[j]
            if fi.is_similar_to(fj):
                bound = WIN_BOUND
            elif bound > 0:
                bound -= 1
            else:
                para_new = j - WIN_BOUND
                self._append_sub(PredictedSubtitle(self.pred_frames[i:para_new], sim_threshold))
                i = para_new
                j = i
                bound = WIN_BOUND
            j += 1
        if i < len(self.pred_frames) - 1:
            self._append_sub(PredictedSubtitle(self.pred_frames[i:], sim_threshold))

    def _append_sub(self, sub):
        """Append ``sub`` to ``self.pred_subs``, merging it with similar preceding subtitles.

        Subtitles with empty text are dropped. While the last stored subtitle
        is similar (``PredictedSubtitle.is_similar_to``), it is popped and
        its frames are combined with ``sub``'s into one subtitle.
        """
        if not sub.text:
            return
        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
            ls = self.pred_subs.pop()
            sub = PredictedSubtitle(ls.frames + sub.frames, sub.sim_threshold)
        self.pred_subs.append(sub)
# Gradio app
def extract_subtitles(video_file, lang, time_start, time_end, conf_threshold, use_fullframe, sim_threshold):
    """Gradio handler: OCR hard-coded subtitles from an uploaded video and return SRT text.

    Args:
        video_file: the uploaded video. ``gr.Video`` delivers a file-path
            string; objects exposing the path as ``.name`` (tempfile wrappers
            from file components) are accepted too.
        lang: Tesseract language spec; blank falls back to ``'eng'``.
        time_start: optional start timestamp; surrounding whitespace is ignored.
        time_end: optional end timestamp; surrounding whitespace is ignored.
        conf_threshold: per-word confidence threshold.
        use_fullframe: OCR the whole frame instead of only its bottom half.
        sim_threshold: similarity threshold used to merge subtitles.

    Raises:
        ValueError: if no video was supplied or an input is invalid.
    """
    if video_file is None:
        raise ValueError('Please upload a video file.')
    # gr.Video passes a path (str), which has no .name attribute.
    video_path = getattr(video_file, 'name', video_file)
    lang = (lang or '').strip() or 'eng'
    time_start = (time_start or '').strip()
    time_end = (time_end or '').strip()
    # Video.run_ocr points Tesseract at TESSDATA_DIR, so the requested
    # language files must be present there before any frame is processed.
    download_lang_data(lang)
    video = Video(video_path)
    video.run_ocr(lang, time_start, time_end, conf_threshold, use_fullframe)
    return video.get_subtitles(sim_threshold)
iface = gr.Interface(
    fn=extract_subtitles,
    inputs=[
        gr.Video(label="Video File"),
        gr.Textbox(value='eng', label="OCR Language"),
        gr.Textbox(value='00:00:00', label="Start Time (HH:MM:SS)"),
        gr.Textbox(value='', label="End Time (HH:MM:SS, leave empty for full video)"),
        gr.Slider(0, 100, value=60, step=1, label="Confidence Threshold"),
        # Gradio 3+ renamed `default=` to `value=`; newer releases reject `default=`.
        gr.Checkbox(value=False, label="Use Full Frame for OCR"),
        gr.Slider(0, 100, value=70, step=1, label="Similarity Threshold"),
    ],
    outputs=gr.Textbox(label="Extracted Subtitles"),
    title="Video Subtitle Extractor",
    description="Extract hardcoded subtitles from videos using machine learning.",
)

# Launch only when run as a script. OCR uses multiprocessing.Pool, and under
# the "spawn" start method each worker re-imports this module; an unguarded
# launch() would then try to start a server in every worker.
if __name__ == '__main__':
    iface.launch()