import json
import os
import uuid
from abc import ABC
from pathlib import Path
from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple
from zipfile import ZipFile

import torch

from doctr.models.preprocessor import PreProcessor
from doctr.models.recognition.predictor import RecognitionPredictor  # pylint: disable=W0611
from doctr.models.recognition.zoo import ARCHS, recognition

# Numpy image type
import numpy.typing as npt
from numpy import uint8

from utils import WordAnnotation, getlogger

ImageType = npt.NDArray[uint8]


class DoctrTextRecognizer:
    def __init__(
        self,
        architecture: str,
        path_weights: str,
        path_config_json: Optional[str] = None,
    ) -> None:
| """ | |
| :param architecture: DocTR supports various text recognition models, e.g. "crnn_vgg16_bn", | |
| "crnn_mobilenet_v3_small". The full list can be found here: | |
| https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py#L16. | |
| :param path_weights: Path to the weights of the model | |
| :param device: "cpu" or "cuda". | |
| :param lib: "TF" or "PT" or None. If None, env variables USE_TENSORFLOW, USE_PYTORCH will be used. | |
| :param path_config_json: Path to a json file containing the configuration of the model. Useful, if you have | |
| a model trained on custom vocab. | |
| """ | |
        self.architecture = architecture
        self.path_weights = path_weights
        self.name = self.get_name(self.path_weights, self.architecture)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.path_config_json = path_config_json
        self.built_model = self.build_model(self.architecture, self.path_config_json)
        self.load_model(self.path_weights, self.built_model, self.device)
        self.doctr_predictor = self.get_wrapped_model()
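
    # A minimal construction sketch (the paths below are hypothetical placeholders; any architecture
    # listed in doctr's recognition zoo, e.g. "crnn_vgg16_bn", can be passed):
    #
    #   recognizer = DoctrTextRecognizer(
    #       architecture="crnn_vgg16_bn",
    #       path_weights="/models/crnn_vgg16_bn_pt.pt",
    #       path_config_json="/models/crnn_vgg16_bn_pt.json",
    #   )
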
    def predict(self, inputs: Dict[uuid.UUID, Tuple[ImageType, WordAnnotation]]) -> List[WordAnnotation]:
        """
        Prediction on a batch of text lines.

        :param inputs: Dictionary where the key is the word's object id and the value is a tuple of the
            cropped image and the word annotation.
        :return: A list of WordAnnotation objects with the recognized text filled in.
        """
        if inputs:
            predictor = self.doctr_predictor
            word_uuids = list(inputs.keys())
            cropped_images = [value[0] for value in inputs.values()]
            raw_output = predictor(list(cropped_images))
            det_results = []
            # zip predictions back onto their annotations; `word_uuid` avoids shadowing the uuid module
            for word_uuid, output in zip(word_uuids, raw_output):
                ann = inputs[word_uuid][1]
                ann.text = output[0]
                det_results.append(ann)
            return det_results
        return []
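
    # Usage sketch for predict (hypothetical; assumes each WordAnnotation exposes an id usable as the
    # dictionary key and a mutable `text` attribute, and that each crop is an HxWx3 uint8 array):
    #
    #   batch = {word_id: (crop, ann) for word_id, crop, ann in zip(word_ids, crops, annotations)}
    #   for ann in recognizer.predict(batch):
    #       print(ann.text)
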
    def predict_for_tables(self, inputs: List[ImageType]) -> List[str]:
        """
        Prediction on a batch of cropped table cell images.

        :param inputs: List of cropped images.
        :return: A list of recognized strings, one per input image.
        """
        if inputs:
            predictor = self.doctr_predictor
            raw_output = predictor(list(inputs))
            det_results = []
            for output in raw_output:
                det_results.append(output[0])
            return det_results
        return []
    @staticmethod
    def load_model(path_weights: str, doctr_predictor: Any, device: torch.device) -> None:
        """
        Load model weights into the wrapped predictor.

        1. Load the state dictionary: torch.load(path_weights, map_location=device) loads the state
           dictionary from the given file path and maps it to the specified device.
        2. Modify keys in the state dictionary: "model." is prepended to each key so that the keys match
           those expected by the doctr_predictor wrapper.
        3. Load the modified state dictionary into the model via load_state_dict.
        4. Move the model to the specified device.
        """
        state_dict = torch.load(path_weights, map_location=device)
        for key in list(state_dict.keys()):
            state_dict["model." + key] = state_dict.pop(key)
        doctr_predictor.load_state_dict(state_dict)
        doctr_predictor.to(device)
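
    # Illustration of the key remapping above (the key name is hypothetical): a checkpoint entry stored
    # as "feat_extractor.0.weight" is re-inserted as "model.feat_extractor.0.weight", so that it matches
    # the parameter names of the RecognitionPredictor wrapper before load_state_dict runs.
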
    @staticmethod
    def build_model(architecture: str, path_config_json: Optional[str] = None) -> "RecognitionPredictor":
        """
        Build the recognition model.

        1. If a config JSON is given, the keys "arch", "url" and "task" are removed from the custom
           configs, and the "mean" and "std" values are moved into the pre-processor configs.
        2. Create the model:
           Case 1: if architecture is a string, it must be one of the predefined architectures (ARCHS);
           an instance of that model is created with the custom configs.
           Case 2: if architecture is not a string, it must be an instance of one of the recognized model
           classes (e.g. recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR,
           recognition.PARSeq); the provided instance is used as the model.
        3. Retrieve the input_shape from the model's configuration.
        4. Return a RecognitionPredictor initialized with a PreProcessor and the model. See the config
           sketch after this method.
        """
        # inspired and adapted from https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py
        custom_configs = {}
        batch_size = 1024
        recognition_configs = {}
        if path_config_json:
            with open(path_config_json, "r", encoding="utf-8") as f:
                custom_configs = json.load(f)
            custom_configs.pop("arch", None)
            custom_configs.pop("url", None)
            custom_configs.pop("task", None)
            recognition_configs["mean"] = custom_configs.pop("mean")
            recognition_configs["std"] = custom_configs.pop("std")
            # batch_size = custom_configs.pop("batch_size")
        recognition_configs["batch_size"] = batch_size
        if isinstance(architecture, str):
            if architecture not in ARCHS:
                raise ValueError(f"unknown architecture '{architecture}'")
            model = recognition.__dict__[architecture](pretrained=True, pretrained_backbone=True, **custom_configs)
        else:
            if not isinstance(
                architecture,
                (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq),
            ):
                raise ValueError(f"unknown architecture: {type(architecture)}")
            model = architecture
        input_shape = model.cfg["input_shape"][-2:]
| """ | |
| (class) PreProcessor | |
| Implements an abstract preprocessor object which performs casting, resizing, batching and normalization. | |
| Args: | |
| output_size: expected size of each page in format (H, W) | |
| batch_size: the size of page batches | |
| mean: mean value of the training distribution by channel | |
| std: standard deviation of the training distribution by channel | |
| """ | |
| return RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **recognition_configs), model) | |
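
    # A sketch of a custom-vocab config JSON that build_model can consume. "mean" and "std" are the only
    # keys handled explicitly; anything left after popping "arch"/"url"/"task" is forwarded to the doctr
    # model constructor, so extra keys such as "vocab" depend on the chosen architecture (values below
    # are illustrative, not authoritative):
    #
    #   {
    #       "arch": "crnn_vgg16_bn",
    #       "mean": [0.694, 0.695, 0.693],
    #       "std": [0.299, 0.296, 0.301],
    #       "vocab": "0123456789abcdefghijklmnopqrstuvwxyz"
    #   }
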
    def get_wrapped_model(self) -> Any:
        """
        Get the inner (wrapped) model.
        """
        doctr_predictor = self.build_model(self.architecture, self.path_config_json)
        self.load_model(self.path_weights, doctr_predictor, self.device)
        return doctr_predictor
    @staticmethod
    def get_name(path_weights: str, architecture: str) -> str:
        """Returns the name of the model"""
        return f"doctr_{architecture}" + "_".join(Path(path_weights).parts[-2:])