Spaces:
Running
on
Zero
Running
on
Zero
| from typing import Protocol, Tuple | |
| import gradio as gr | |
| import mmcv | |
| import numpy as np | |
| import torch | |
| from transformers import AutoImageProcessor, TrOCRProcessor, VisionEncoderDecoderModel | |
| from src.htr_pipeline.models import HtrModels | |
| from src.htr_pipeline.utils.filter_segmask import FilterSegMask | |
| from src.htr_pipeline.utils.helper import timer_func | |
| from src.htr_pipeline.utils.order_of_object import OrderObject | |
| from src.htr_pipeline.utils.preprocess_img import Preprocess | |
| from src.htr_pipeline.utils.process_segmask import SegMaskHelper | |
| class Inferencer: | |
| def __init__(self, local_run=False): | |
| htr_models = HtrModels(local_run) | |
| self.seg_model = htr_models.load_region_model() | |
| self.line_model = htr_models.load_line_model() | |
| self.htr_model_inferencer = htr_models.load_htr_model() | |
| self.process_seg_mask = SegMaskHelper() | |
| self.postprocess_seg_mask = FilterSegMask() | |
| self.ordering = OrderObject() | |
| self.preprocess_img = Preprocess() | |
| def predict_regions(self, input_image, pred_score_threshold=0.5, containments_threshold=0.5, visualize=True): | |
| input_image = self.preprocess_img.binarize_img(input_image) | |
| image = mmcv.imread(input_image) | |
| result = self.seg_model(image, return_datasample=True) | |
| result_pred = result["predictions"][0] | |
| filtered_result_pred = self.postprocess_seg_mask.filter_on_pred_threshold( | |
| result_pred, pred_score_threshold=pred_score_threshold | |
| ) | |
| if len(filtered_result_pred.pred_instances.masks) == 0: | |
| raise gr.Error("No Regions were predicted by the model") | |
| else: | |
| result_align = self.process_seg_mask.align_masks_with_image(filtered_result_pred, image) | |
| result_clean = self.postprocess_seg_mask.remove_overlapping_masks( | |
| predicted_mask=result_align, containments_threshold=containments_threshold | |
| ) | |
| if visualize: | |
| result_viz = self.seg_model.visualize( | |
| inputs=[image], preds=[result_clean], return_vis=True, no_save_vis=True | |
| )[0] | |
| else: | |
| result_viz = None | |
| regions_cropped, polygons = self.process_seg_mask.crop_masks(result_clean, image) | |
| order = self.ordering.order_regions_marginalia(result_clean) | |
| regions_cropped_ordered = [regions_cropped[i] for i in order] | |
| polygons_ordered = [polygons[i] for i in order] | |
| masks_ordered = [result_clean.pred_instances.masks[i] for i in order] | |
| return result_viz, regions_cropped_ordered, polygons_ordered, masks_ordered | |
| def predict_lines( | |
| self, | |
| image, | |
| pred_score_threshold=0.5, | |
| containments_threshold=0.5, | |
| line_spacing_factor=0.5, | |
| visualize=True, | |
| custom_track=True, | |
| ): | |
| result_tl = self.line_model(image, return_datasample=True) | |
| result_tl_pred = result_tl["predictions"][0] | |
| filtered_result_tl_pred = self.postprocess_seg_mask.filter_on_pred_threshold( | |
| result_tl_pred, pred_score_threshold=pred_score_threshold | |
| ) | |
| if len(filtered_result_tl_pred.pred_instances.masks) == 0 and custom_track: | |
| raise gr.Error("No Lines were predicted by the model") | |
| elif len(filtered_result_tl_pred.pred_instances.masks) == 0 and not custom_track: | |
| return None, None, None | |
| else: | |
| result_tl_align = self.process_seg_mask.align_masks_with_image(filtered_result_tl_pred, image) | |
| result_tl_clean = self.postprocess_seg_mask.remove_overlapping_masks( | |
| predicted_mask=result_tl_align, containments_threshold=containments_threshold | |
| ) | |
| if visualize: | |
| result_viz = self.seg_model.visualize( | |
| inputs=[image], | |
| preds=[result_tl_clean], | |
| return_vis=True, | |
| no_save_vis=True, | |
| )[0] | |
| else: | |
| result_viz = None | |
| lines_cropped, lines_polygons = self.process_seg_mask.crop_masks(result_tl_clean, image) | |
| ordered_indices = self.ordering.order_lines( | |
| line_image=result_tl_clean, line_spacing_factor=line_spacing_factor | |
| ) | |
| lines_cropped_ordered = [lines_cropped[i] for i in ordered_indices] | |
| lines_polygons_ordered = [lines_polygons[i] for i in ordered_indices] | |
| return result_viz, lines_cropped_ordered, lines_polygons_ordered | |
| def transcribe(self, line_cropped): | |
| result_rec = self.htr_model_inferencer(line_cropped) | |
| return result_rec["predictions"][0]["text"], round(result_rec["predictions"][0]["scores"], 4) | |
| def transcribe_different_model(self, image, htr_tool_transcriber_model_dropdown): | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| if htr_tool_transcriber_model_dropdown == "pstroe/bullinger-general-model": | |
| processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") | |
| image_processor = AutoImageProcessor.from_pretrained("pstroe/bullinger-general-model") | |
| model = VisionEncoderDecoderModel.from_pretrained("pstroe/bullinger-general-model") | |
| pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device) | |
| else: | |
| processor = TrOCRProcessor.from_pretrained(htr_tool_transcriber_model_dropdown) | |
| model = VisionEncoderDecoderModel.from_pretrained(htr_tool_transcriber_model_dropdown) | |
| pixel_values = processor(image, return_tensors="pt").pixel_values.to(device) | |
| model.to(device) | |
| generated_ids = model.generate(pixel_values) | |
| generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] | |
| return generated_text, 1.0 | |
| class InferencerInterface(Protocol): | |
| def predict_regions( | |
| self, | |
| image: np.array, | |
| pred_score_threshold: float, | |
| containments_threshold: float, | |
| visualize: bool = False, | |
| ) -> Tuple: | |
| ... | |
| def predict_lines( | |
| self, | |
| text_region: np.array, | |
| pred_score_threshold: float, | |
| containments_threshold: float, | |
| visualize: bool = False, | |
| custom_track: bool = False, | |
| ) -> Tuple: | |
| ... | |
| def transcribe( | |
| self, | |
| line: np.array, | |
| ) -> Tuple[str, float]: | |
| ... | |
| if __name__ == "__main__": | |
| prediction_model = Inferencer() | |