neko-image-gallery / app /Services /ocr_services.py
eggacheb's picture
Upload 97 files
21db53c verified
raw
history blame
4.87 kB
from time import time
import numpy as np
import torch
from PIL import Image
from loguru import logger
from app.Services.lifespan_service import LifespanService
from app.config import config
class OCRService(LifespanService):
    """Common base for the OCR backends.

    Resolves the compute device from the configuration and provides the
    default image preprocessing shared by the concrete services.
    """

    def __init__(self):
        """Pick the device: an ``"auto"`` setting becomes CUDA when available, else CPU."""
        self._device = config.device
        if self._device == "auto":
            self._device = "cuda" if torch.cuda.is_available() else "cpu"

    @staticmethod
    def _image_preprocess(img: Image.Image) -> Image.Image:
        """Return an RGB 1024x1024 version of ``img`` without modifying ``img``.

        Images larger than 1024 pixels on either side are shrunk to fit
        (aspect ratio preserved); smaller images are not enlarged. The result
        is pasted centred onto a black 1024x1024 canvas.

        :param img: Source image; left untouched.
        :return: A new 1024x1024 RGB image.
        """
        rgb_img = img if img.mode == 'RGB' else img.convert('RGB')
        if rgb_img.size[0] > 1024 or rgb_img.size[1] > 1024:
            # `thumbnail` resizes in place. If no conversion happened above,
            # `rgb_img` is still the caller's object, so work on a copy to
            # avoid shrinking the caller's image as a side effect.
            if rgb_img is img:
                rgb_img = rgb_img.copy()
            rgb_img.thumbnail((1024, 1024), Image.Resampling.LANCZOS)
        canvas = Image.new('RGB', (1024, 1024), (0, 0, 0))
        offset = ((1024 - rgb_img.size[0]) // 2, (1024 - rgb_img.size[1]) // 2)
        canvas.paste(rgb_img, offset)
        return canvas

    def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str:
        """Recognise the text in ``img``; meant to be overridden by subclasses.

        This base implementation does nothing and returns ``None``.

        :param img: Image to run OCR on.
        :param need_preprocess: Whether the image should be preprocessed first.
        """
class EasyPaddleOCRService(OCRService):
    """OCR backend built on the ``easypaddleocr`` package."""

    def __init__(self):
        """Resolve the device and load the EasyPaddleOCR model."""
        super().__init__()
        from easypaddleocr import EasyPaddleOCR

        # An empty/unset configured directory is passed to the library as None.
        local_model_dir = config.model.easypaddleocr or None
        self._paddle_ocr_module = EasyPaddleOCR(
            use_angle_cls=True,
            needWarmUp=True,
            devices=self._device,
            warmup_size=(960, 960),
            model_local_dir=local_model_dir,
        )
        logger.success("EasyPaddleOCR loaded successfully")

    @staticmethod
    def _image_preprocess(img: Image.Image) -> Image.Image:
        """Ensure ``img`` is RGB.

        Optimized `easypaddleocr` doesn't require scaling preprocess, so only
        the colour mode is normalised here.
        """
        return img if img.mode == 'RGB' else img.convert('RGB')

    def _easy_paddleocr_process(self, img: Image.Image) -> str:
        """Run the model on ``img`` and concatenate the sufficiently confident texts.

        Entries whose second field, as a float, does not exceed
        ``config.ocr_search.ocr_min_confidence`` are dropped; the first fields
        of the remaining entries are joined without a separator.
        """
        _, detections, _ = self._paddle_ocr_module.ocr(np.array(img))
        if not detections:
            return ""
        threshold = config.ocr_search.ocr_min_confidence
        accepted = [entry[0] for entry in detections if float(entry[1]) > threshold]
        return "".join(accepted)

    def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str:
        """Recognise the text in ``img`` and log how long it took.

        :param img: Image to run OCR on.
        :param need_preprocess: When true, convert the image to RGB first.
        :return: The recognised text, or an empty string if nothing was found.
        """
        started_at = time()
        logger.info("Processing text with EasyPaddleOCR...")
        if need_preprocess:
            prepared = self._image_preprocess(img)
        else:
            prepared = img
        text = self._easy_paddleocr_process(prepared)
        logger.success("OCR processed done. Time elapsed: {:.2f}s", time() - started_at)
        return text
class EasyOCRService(OCRService):
    """OCR backend built on the ``easyocr`` package."""

    def __init__(self):
        """Resolve the device and create the easyOCR reader for the configured languages."""
        super().__init__()
        # noinspection PyPackageRequirements
        import easyocr  # pylint: disable=import-error

        use_gpu = self._device == "cuda"
        self._easy_ocr_module = easyocr.Reader(config.ocr_search.ocr_language, gpu=use_gpu)
        logger.success("easyOCR loaded successfully")

    def _easyocr_process(self, img: Image.Image) -> str:
        """Run the reader on ``img`` and join the sufficiently confident texts with spaces.

        Entries whose third field does not exceed
        ``config.ocr_search.ocr_min_confidence`` are dropped; the second fields
        of the remaining entries are joined with a single space.
        """
        detections = self._easy_ocr_module.readtext(np.array(img))
        accepted = (
            entry[1]
            for entry in detections
            if entry[2] > config.ocr_search.ocr_min_confidence
        )
        return " ".join(accepted)

    def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str:
        """Recognise the text in ``img`` and log how long it took.

        :param img: Image to run OCR on.
        :param need_preprocess: When true, run ``_image_preprocess`` on the image first.
        :return: The recognised text (empty when nothing passes the threshold).
        """
        started_at = time()
        logger.info("Processing text with easyOCR...")
        if need_preprocess:
            prepared = self._image_preprocess(img)
        else:
            prepared = img
        text = self._easyocr_process(prepared)
        logger.success("OCR processed done. Time elapsed: {:.2f}s", time() - started_at)
        return text
class PaddleOCRService(OCRService):
    """OCR backend built on the ``paddleocr`` package."""

    def __init__(self):
        """Resolve the device and load PaddleOCR (``lang="ch"``, angle classifier on)."""
        super().__init__()
        # noinspection PyPackageRequirements
        import paddleocr  # pylint: disable=import-error

        use_gpu = self._device == "cuda"
        self._paddle_ocr_module = paddleocr.PaddleOCR(
            lang="ch",
            use_angle_cls=True,
            use_gpu=use_gpu,
        )
        logger.success("PaddleOCR loaded successfully")

    def _paddleocr_process(self, img: Image.Image) -> str:
        """Run PaddleOCR on ``img`` and concatenate the sufficiently confident texts.

        Only the first element of the result is used. For each of its entries,
        ``entry[1][0]`` is kept when ``entry[1][1]`` exceeds
        ``config.ocr_search.ocr_min_confidence``; kept values are joined
        without a separator.
        """
        result = self._paddle_ocr_module.ocr(np.array(img), cls=True)
        first_result = result[0]
        if not first_result:
            return ""
        threshold = config.ocr_search.ocr_min_confidence
        accepted = [entry[1][0] for entry in first_result if entry[1][1] > threshold]
        return "".join(accepted)

    def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str:
        """Recognise the text in ``img`` and log how long it took.

        :param img: Image to run OCR on.
        :param need_preprocess: When true, run ``_image_preprocess`` on the image first.
        :return: The recognised text, or an empty string if nothing was found.
        """
        started_at = time()
        logger.info("Processing text with PaddleOCR...")
        if need_preprocess:
            prepared = self._image_preprocess(img)
        else:
            prepared = img
        text = self._paddleocr_process(prepared)
        logger.success("OCR processed done. Time elapsed: {:.2f}s", time() - started_at)
        return text
class DisabledOCRService(OCRService):
    """Stand-in OCR service that loads no model and rejects every OCR request."""

    def __init__(self):
        """Resolve the device via the base class and warn that OCR is disabled."""
        super().__init__()
        logger.warning("OCR search is disabled. Skipping OCR model loading.")

    def ocr_interface(self, img: Image.Image, need_preprocess=True) -> str:
        """Always fail, since no OCR model is loaded.

        :raises NotImplementedError: On every call.
        """
        message = "OCR module is disabled. Consider enable it in config."
        raise NotImplementedError(message)