import gradio as gr
import pytesseract
import cv2
import multiprocessing
from fuzzywuzzy import fuzz
from dataclasses import dataclass
from urllib.request import urlopen
import shutil
import pathlib
import datetime
import sys
# Constants
TESSDATA_DIR = pathlib.Path.home() / 'tessdata'
TESSDATA_URL = 'https://github.com/tesseract-ocr/tessdata_fast/raw/master/{}.traineddata'
TESSDATA_SCRIPT_URL = 'https://github.com/tesseract-ocr/tessdata_best/raw/master/script/{}.traineddata'
# Download language data files if necessary
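# Lowercase codes (e.g. 'eng', 'chi_sim') are fetched from tessdata_fast; capitalized names
# are Tesseract script models (e.g. 'Hani') and come from tessdata_best/script instead.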
def download_lang_data(lang: str):
    TESSDATA_DIR.mkdir(parents=True, exist_ok=True)
    for lang_name in lang.split('+'):
        filepath = TESSDATA_DIR / f'{lang_name}.traineddata'
        if not filepath.is_file():
            url = TESSDATA_SCRIPT_URL.format(lang_name) if lang_name[0].isupper() else TESSDATA_URL.format(lang_name)
            with urlopen(url) as res, open(filepath, 'w+b') as f:
                shutil.copyfileobj(res, f)

# Helper functions for time and frame conversion
def get_frame_index(time_str: str, fps: float):
    t = list(map(float, time_str.split(':')))
    if len(t) == 3:
        td = datetime.timedelta(hours=t[0], minutes=t[1], seconds=t[2])
    elif len(t) == 2:
        td = datetime.timedelta(minutes=t[0], seconds=t[1])
    else:
        raise ValueError(f'Time data "{time_str}" does not match format "%H:%M:%S"')
    return int(td.total_seconds() * fps)
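# Example: get_frame_index('00:01:30', 25.0) == 2250  (90 seconds * 25 fps)
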
def get_srt_timestamp(frame_index: int, fps: float):
    td = datetime.timedelta(seconds=frame_index / fps)
    ms = td.microseconds // 1000
    m, s = divmod(td.seconds, 60)
    h, m = divmod(m, 60)
    return f'{h:02d}:{m:02d}:{s:02d},{ms:03d}'
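# Example: get_srt_timestamp(2250, 25.0) == '00:01:30,000'
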
# Video capture class using OpenCV
class Capture:
    def __init__(self, video_path):
        self.path = video_path

    def __enter__(self):
        self.cap = cv2.VideoCapture(self.path)
        if not self.cap.isOpened():
            raise IOError(f'Cannot open video {self.path}.')
        return self.cap

    def __exit__(self, exc_type, exc_value, traceback):
        self.cap.release()
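# Typical usage (as done in Video below):
#   with Capture(path) as cap:
#       ok, frame = cap.read()
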
@dataclass
class PredictedWord:
    confidence: int
    text: str

class PredictedFrame:
    def __init__(self, index: int, pred_data: str, conf_threshold: int):
        self.index = index
        self.words = []
        block = 0

        # pred_data is the TSV output of pytesseract.image_to_data; the first line is the
        # column header (level, page_num, block_num, ..., conf, text), so it is skipped.
        for line in pred_data.splitlines()[1:]:
            word_data = line.split()
            if len(word_data) < 12:
                continue
            _, _, block_num, *_, conf, text = word_data
            # conf may be an integer or a float string depending on the Tesseract version.
            block_num, conf = int(block_num), int(float(conf))

            # Insert a line break between text blocks so distinct lines stay separated.
            if block < block_num:
                block = block_num
                if self.words and self.words[-1].text != '\n':
                    self.words.append(PredictedWord(0, '\n'))

            if conf >= conf_threshold:
                self.words.append(PredictedWord(conf, text))

        self.confidence = sum(word.confidence for word in self.words)
        # Drop characters that are almost always OCR noise ('|' is mapped to 'I') and normalize line breaks.
        self.text = (
            ' '.join(word.text for word in self.words)
            .translate(str.maketrans('|', 'I', '<>{}[];`@#$%^*_=~\\'))
            .replace(' \n ', '\n')
            .strip()
        )

    def is_similar_to(self, other, threshold=70):
        return fuzz.ratio(self.text, other.text) >= threshold

class PredictedSubtitle:
    def __init__(self, frames, sim_threshold):
        self.frames = [f for f in frames if f.confidence > 0]
        self.sim_threshold = sim_threshold
        self.text = max(self.frames, key=lambda f: f.confidence).text if self.frames else ''

    @property
    def index_start(self):
        return self.frames[0].index if self.frames else 0

    @property
    def index_end(self):
        return self.frames[-1].index if self.frames else 0

    def is_similar_to(self, other):
        return fuzz.partial_ratio(self.text, other.text) >= self.sim_threshold
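# A PredictedSubtitle keeps the text of its highest-confidence frame. is_similar_to uses
# fuzz.partial_ratio, which scores the best-matching substring alignment, so frames where the
# subtitle was only partially recognized still merge with their neighbours.
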
class Video:
    def __init__(self, path):
        self.path = path
        self.pred_frames = None
        with Capture(path) as v:
            self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
            self.fps = v.get(cv2.CAP_PROP_FPS)
            self.height = int(v.get(cv2.CAP_PROP_FRAME_HEIGHT))

    def run_ocr(self, lang, time_start, time_end, conf_threshold, use_fullframe):
        self.lang = lang
        self.use_fullframe = use_fullframe
        ocr_start = get_frame_index(time_start, self.fps) if time_start else 0
        ocr_end = get_frame_index(time_end, self.fps) if time_end else self.num_frames
        if ocr_end < ocr_start:
            raise ValueError('time_start is later than time_end')
        num_ocr_frames = ocr_end - ocr_start

        # Read the requested frame range and OCR the frames in parallel worker processes.
        with Capture(self.path) as v, multiprocessing.Pool() as pool:
            v.set(cv2.CAP_PROP_POS_FRAMES, ocr_start)
            frames = (v.read()[1] for _ in range(num_ocr_frames))
            it_ocr = pool.imap(self._image_to_data, frames, chunksize=10)
            self.pred_frames = [
                PredictedFrame(i + ocr_start, data, conf_threshold)
                for i, data in enumerate(it_ocr)
            ]
    def _image_to_data(self, img):
        if img is None:
            # Frame read failed; return an empty result so this frame is simply skipped.
            return ''
        if not self.use_fullframe:
            # Hardcoded subtitles normally sit in the lower half of the frame.
            img = img[self.height // 2:, :]
        config = f'--tessdata-dir "{TESSDATA_DIR}"'
        try:
            return pytesseract.image_to_data(img, lang=self.lang, config=config)
        except Exception as e:
            sys.exit(f'{e.__class__.__name__}: {e}')
    def get_subtitles(self, sim_threshold):
        self._generate_subtitles(sim_threshold)
        # SRT sequence numbers start at 1.
        return ''.join(
            f'{i}\n{get_srt_timestamp(sub.index_start, self.fps)} --> {get_srt_timestamp(sub.index_end, self.fps)}\n{sub.text}\n\n'
            for i, sub in enumerate(self.pred_subs, start=1))
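
    # Each emitted block follows the SRT layout, e.g.:
    #   1
    #   00:00:01,000 --> 00:00:03,500
    #   Recognized subtitle text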

    def _generate_subtitles(self, sim_threshold):
        self.pred_subs = []

        if self.pred_frames is None:
            raise AttributeError('Please call self.run_ocr() first to perform OCR on frames')

        # Scan frames with a sliding window of about half a second: consecutive frames whose
        # text is similar belong to the same subtitle; once the window of dissimilar frames
        # is exhausted, close the current subtitle and start a new one.
        WIN_BOUND = int(self.fps // 2)
        bound = WIN_BOUND
        i = 0
        j = 1
        while j < len(self.pred_frames):
            fi, fj = self.pred_frames[i], self.pred_frames[j]

            if fi.is_similar_to(fj):
                bound = WIN_BOUND
            elif bound > 0:
                bound -= 1
            else:
                # Text changed more than WIN_BOUND frames ago: emit the finished subtitle.
                para_new = j - WIN_BOUND
                self._append_sub(PredictedSubtitle(self.pred_frames[i:para_new], sim_threshold))
                i = para_new
                j = i
                bound = WIN_BOUND

            j += 1

        # Emit the trailing subtitle, if any frames remain.
        if i < len(self.pred_frames) - 1:
            self._append_sub(PredictedSubtitle(self.pred_frames[i:], sim_threshold))
    def _append_sub(self, sub):
        if not sub.text:
            return
        # Merge with the previous subtitle while their texts are similar enough.
        while self.pred_subs and sub.is_similar_to(self.pred_subs[-1]):
            ls = self.pred_subs.pop()
            sub = PredictedSubtitle(ls.frames + sub.frames, sub.sim_threshold)
        self.pred_subs.append(sub)
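
# Programmatic use of the pipeline above (illustrative sketch; 'movie.mp4' is a placeholder path):
#   video = Video('movie.mp4')
#   video.run_ocr('eng', '00:00:00', '00:01:00', conf_threshold=60, use_fullframe=False)
#   print(video.get_subtitles(sim_threshold=70))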
# Gradio app
def extract_subtitles(video_file, lang, time_start, time_end, conf_threshold, use_fullframe, sim_threshold):
    # gr.Video may pass either a file path string or a temp-file object depending on the Gradio version.
    video_path = video_file if isinstance(video_file, str) else video_file.name
    download_lang_data(lang)
    video = Video(video_path)
    video.run_ocr(lang, time_start, time_end, conf_threshold, use_fullframe)
    return video.get_subtitles(sim_threshold)
iface = gr.Interface(
    fn=extract_subtitles,
    inputs=[
        gr.Video(label="Video File"),
        gr.Textbox(value='eng', label="OCR Language"),
        gr.Textbox(value='00:00:00', label="Start Time (HH:MM:SS)"),
        gr.Textbox(value='', label="End Time (HH:MM:SS, leave empty for full video)"),
        gr.Slider(0, 100, value=60, step=1, label="Confidence Threshold"),
        gr.Checkbox(value=False, label="Use Full Frame for OCR"),
        gr.Slider(0, 100, value=70, step=1, label="Similarity Threshold"),
    ],
    outputs=gr.Textbox(label="Extracted Subtitles"),
    title="Video Subtitle Extractor",
    description="Extract hardcoded subtitles from videos using Tesseract OCR.",
)
iface.launch()