Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Synced repo using 'sync_with_huggingface' Github Action
Browse files- gradio_app.py +254 -0
- requirements.txt +4 -0
    	
        gradio_app.py
    ADDED
    
    | @@ -0,0 +1,254 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import sys
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            if "APP_PATH" in os.environ:
         | 
| 5 | 
            +
                os.chdir(os.environ["APP_PATH"])
         | 
| 6 | 
            +
                # fix sys.path for import
         | 
| 7 | 
            +
                sys.path.append(os.getcwd())
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from typing import List
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import gradio as gr
         | 
| 12 | 
            +
            import pypdfium2
         | 
| 13 | 
            +
            from pypdfium2 import PdfiumError
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from surya.detection import batch_text_detection
         | 
| 16 | 
            +
            from surya.input.pdflines import get_page_text_lines, get_table_blocks
         | 
| 17 | 
            +
            from surya.layout import batch_layout_detection
         | 
| 18 | 
            +
            from surya.model.detection.model import load_model, load_processor
         | 
| 19 | 
            +
            from surya.model.layout.model import load_model as load_layout_model
         | 
| 20 | 
            +
            from surya.model.layout.processor import load_processor as load_layout_processor
         | 
| 21 | 
            +
            from surya.model.recognition.model import load_model as load_rec_model
         | 
| 22 | 
            +
            from surya.model.recognition.processor import load_processor as load_rec_processor
         | 
| 23 | 
            +
            from surya.model.table_rec.model import load_model as load_table_model
         | 
| 24 | 
            +
            from surya.model.table_rec.processor import load_processor as load_table_processor
         | 
| 25 | 
            +
            from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
         | 
| 26 | 
            +
            from surya.ocr import run_ocr
         | 
| 27 | 
            +
            from surya.postprocessing.text import draw_text_on_image
         | 
| 28 | 
            +
            from PIL import Image
         | 
| 29 | 
            +
            from surya.languages import CODE_TO_LANGUAGE
         | 
| 30 | 
            +
            from surya.input.langs import replace_lang_with_code
         | 
| 31 | 
            +
            from surya.schema import OCRResult, TextDetectionResult, LayoutResult, TableResult
         | 
| 32 | 
            +
            from surya.settings import settings
         | 
| 33 | 
            +
            from surya.tables import batch_table_recognition
         | 
| 34 | 
            +
            from surya.postprocessing.util import rescale_bboxes, rescale_bbox
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            def load_det_cached():
         | 
| 38 | 
            +
                return load_model(), load_processor()
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            def load_rec_cached():
         | 
| 41 | 
            +
                return load_rec_model(), load_rec_processor()
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            def load_layout_cached():
         | 
| 44 | 
            +
                return load_layout_model(), load_layout_processor()
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            def load_table_cached():
         | 
| 47 | 
            +
                return load_table_model(), load_table_processor()
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
| 50 | 
            +
            def text_detection(img) -> (Image.Image, TextDetectionResult):
         | 
| 51 | 
            +
                pred = batch_text_detection([img], det_model, det_processor)[0]
         | 
| 52 | 
            +
                polygons = [p.polygon for p in pred.bboxes]
         | 
| 53 | 
            +
                det_img = draw_polys_on_image(polygons, img.copy())
         | 
| 54 | 
            +
                return det_img, pred
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            def layout_detection(img) -> (Image.Image, LayoutResult):
         | 
| 58 | 
            +
                pred = batch_layout_detection([img], layout_model, layout_processor)[0]
         | 
| 59 | 
            +
                polygons = [p.polygon for p in pred.bboxes]
         | 
| 60 | 
            +
                labels = [f"{p.label}-{p.position}" for p in pred.bboxes]
         | 
| 61 | 
            +
                layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18)
         | 
| 62 | 
            +
                return layout_img, pred
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            def table_recognition(img, highres_img, filepath, page_idx: int, use_pdf_boxes: bool, skip_table_detection: bool) -> (Image.Image, List[TableResult]):
         | 
| 66 | 
            +
                if skip_table_detection:
         | 
| 67 | 
            +
                    layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])]
         | 
| 68 | 
            +
                    table_imgs = [highres_img]
         | 
| 69 | 
            +
                else:
         | 
| 70 | 
            +
                    _, layout_pred = layout_detection(img)
         | 
| 71 | 
            +
                    layout_tables_lowres = [l.bbox for l in layout_pred.bboxes if l.label == "Table"]
         | 
| 72 | 
            +
                    table_imgs = []
         | 
| 73 | 
            +
                    layout_tables = []
         | 
| 74 | 
            +
                    for tb in layout_tables_lowres:
         | 
| 75 | 
            +
                        highres_bbox = rescale_bbox(tb, img.size, highres_img.size)
         | 
| 76 | 
            +
                        table_imgs.append(
         | 
| 77 | 
            +
                            highres_img.crop(highres_bbox)
         | 
| 78 | 
            +
                        )
         | 
| 79 | 
            +
                        layout_tables.append(highres_bbox)
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                try:
         | 
| 82 | 
            +
                    page_text = get_page_text_lines(filepath, [page_idx], [highres_img.size])[0]
         | 
| 83 | 
            +
                    table_bboxes = get_table_blocks(layout_tables, page_text, highres_img.size)
         | 
| 84 | 
            +
                except PdfiumError:
         | 
| 85 | 
            +
                    # This happens when we try to get text from an image
         | 
| 86 | 
            +
                    table_bboxes = [[] for _ in layout_tables]
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                if not use_pdf_boxes or any(len(tb) == 0 for tb in table_bboxes):
         | 
| 89 | 
            +
                    det_results = batch_text_detection(table_imgs, det_model, det_processor)
         | 
| 90 | 
            +
                    table_bboxes = [[{"bbox": tb.bbox, "text": None} for tb in det_result.bboxes] for det_result in det_results]
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                table_preds = batch_table_recognition(table_imgs, table_bboxes, table_model, table_processor)
         | 
| 93 | 
            +
                table_img = img.copy()
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                for results, table_bbox in zip(table_preds, layout_tables):
         | 
| 96 | 
            +
                    adjusted_bboxes = []
         | 
| 97 | 
            +
                    labels = []
         | 
| 98 | 
            +
                    colors = []
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    for item in results.rows + results.cols:
         | 
| 101 | 
            +
                        adjusted_bboxes.append([
         | 
| 102 | 
            +
                            (item.bbox[0] + table_bbox[0]),
         | 
| 103 | 
            +
                            (item.bbox[1] + table_bbox[1]),
         | 
| 104 | 
            +
                            (item.bbox[2] + table_bbox[0]),
         | 
| 105 | 
            +
                            (item.bbox[3] + table_bbox[1])
         | 
| 106 | 
            +
                        ])
         | 
| 107 | 
            +
                        labels.append(item.label)
         | 
| 108 | 
            +
                        if hasattr(item, "row_id"):
         | 
| 109 | 
            +
                            colors.append("blue")
         | 
| 110 | 
            +
                        else:
         | 
| 111 | 
            +
                            colors.append("red")
         | 
| 112 | 
            +
                    table_img = draw_bboxes_on_image(adjusted_bboxes, highres_img, labels=labels, label_font_size=18, color=colors)
         | 
| 113 | 
            +
                return table_img, table_preds
         | 
| 114 | 
            +
             | 
| 115 | 
            +
            def open_pdf(pdf_file):
         | 
| 116 | 
            +
                return pypdfium2.PdfDocument(pdf_file)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
            def count_pdf(pdf_file):
         | 
| 119 | 
            +
                doc = open_pdf(pdf_file)
         | 
| 120 | 
            +
                return len(doc)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
            def get_page_image(pdf_file, page_num, dpi=96):
         | 
| 123 | 
            +
                doc = open_pdf(pdf_file)
         | 
| 124 | 
            +
                renderer = doc.render(
         | 
| 125 | 
            +
                    pypdfium2.PdfBitmap.to_pil,
         | 
| 126 | 
            +
                    page_indices=[page_num - 1],
         | 
| 127 | 
            +
                    scale=dpi / 72,
         | 
| 128 | 
            +
                )
         | 
| 129 | 
            +
                png = list(renderer)[0]
         | 
| 130 | 
            +
                png_image = png.convert("RGB")
         | 
| 131 | 
            +
                return png_image
         | 
| 132 | 
            +
             | 
| 133 | 
            +
            def get_uploaded_image(in_file):
         | 
| 134 | 
            +
                return Image.open(in_file).convert("RGB")
         | 
| 135 | 
            +
             | 
| 136 | 
            +
            # Function for OCR
         | 
| 137 | 
            +
            def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult):
         | 
| 138 | 
            +
                replace_lang_with_code(langs)
         | 
| 139 | 
            +
                img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor, highres_images=[highres_img])[0]
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                bboxes = [l.bbox for l in img_pred.text_lines]
         | 
| 142 | 
            +
                text = [l.text for l in img_pred.text_lines]
         | 
| 143 | 
            +
                rec_img = draw_text_on_image(bboxes, text, img.size, langs, has_math="_math" in langs)
         | 
| 144 | 
            +
                return rec_img, img_pred
         | 
| 145 | 
            +
             | 
| 146 | 
            +
             | 
| 147 | 
            +
            det_model, det_processor = load_det_cached()
         | 
| 148 | 
            +
            rec_model, rec_processor = load_rec_cached()
         | 
| 149 | 
            +
            layout_model, layout_processor = load_layout_cached()
         | 
| 150 | 
            +
            table_model, table_processor = load_table_cached()
         | 
| 151 | 
            +
             | 
| 152 | 
            +
            with gr.Blocks(title="Surya") as demo:
         | 
| 153 | 
            +
                gr.Markdown("""
         | 
| 154 | 
            +
                # Surya OCR Demo
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                This app will let you try surya, a multilingual OCR model. It supports text detection + layout analysis in any language, and text recognition in 90+ languages.
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                Notes:
         | 
| 159 | 
            +
                - This works best on documents with printed text.
         | 
| 160 | 
            +
                - Preprocessing the image (e.g. increasing contrast) can improve results.
         | 
| 161 | 
            +
                - If OCR doesn't work, try changing the resolution of your image (increase if below 2048px width, otherwise decrease).
         | 
| 162 | 
            +
                - This supports 90+ languages, see [here](https://github.com/VikParuchuri/surya/tree/master/surya/languages.py) for a full list.
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                Find the project [here](https://github.com/VikParuchuri/surya).
         | 
| 165 | 
            +
                """)
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                with gr.Row():
         | 
| 168 | 
            +
                    with gr.Column():
         | 
| 169 | 
            +
                        in_file = gr.File(label="PDF file or image:", file_types=[".pdf", ".png", ".jpg", ".jpeg", ".gif", ".webp"])
         | 
| 170 | 
            +
                        in_num = gr.Slider(label="Page number", minimum=1, maximum=100, value=1, step=1)
         | 
| 171 | 
            +
                        in_img = gr.Image(label="Select page of Image", type="pil", sources=None)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                        text_det_btn = gr.Button("Run Text Detection")
         | 
| 174 | 
            +
                        layout_det_btn = gr.Button("Run Layout Analysis")
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                        lang_dd = gr.Dropdown(label="Languages", choices=sorted(list(CODE_TO_LANGUAGE.values())), multiselect=True, max_choices=4, info="Select the languages in the image (if known) to improve OCR accuracy.  Optional.")
         | 
| 177 | 
            +
                        text_rec_btn = gr.Button("Run OCR")
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                        use_pdf_boxes_ckb = gr.Checkbox(label="Use PDF table boxes", value=True, info="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.")
         | 
| 180 | 
            +
                        skip_table_detection_ckb = gr.Checkbox(label="Skip table detection", value=False, info="Table recognition only: Skip table detection and treat the whole image/page as a table.")
         | 
| 181 | 
            +
                        table_rec_btn = gr.Button("Run Table Rec")
         | 
| 182 | 
            +
                    with gr.Column():
         | 
| 183 | 
            +
                        result_img = gr.Image(label="Result image")
         | 
| 184 | 
            +
                        result_json = gr.JSON(label="Result json")
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                    def show_image(file, num=1):
         | 
| 187 | 
            +
                        if file.endswith('.pdf'):
         | 
| 188 | 
            +
                            count = count_pdf(file)
         | 
| 189 | 
            +
                            img = get_page_image(file, num)
         | 
| 190 | 
            +
                            return [
         | 
| 191 | 
            +
                                gr.update(visible=True, maximum=count),
         | 
| 192 | 
            +
                                gr.update(value=img)]
         | 
| 193 | 
            +
                        else:
         | 
| 194 | 
            +
                            img = get_uploaded_image(file)
         | 
| 195 | 
            +
                            return [
         | 
| 196 | 
            +
                                gr.update(visible=False),
         | 
| 197 | 
            +
                                gr.update(value=img)]
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                    in_file.upload(
         | 
| 200 | 
            +
                        fn=show_image,
         | 
| 201 | 
            +
                        inputs=[in_file],
         | 
| 202 | 
            +
                        outputs=[in_num, in_img],
         | 
| 203 | 
            +
                    )
         | 
| 204 | 
            +
                    in_num.change(
         | 
| 205 | 
            +
                        fn=show_image,
         | 
| 206 | 
            +
                        inputs=[in_file, in_num],
         | 
| 207 | 
            +
                        outputs=[in_num, in_img],
         | 
| 208 | 
            +
                    )
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    # Run Text Detection
         | 
| 211 | 
            +
                    def text_det_img(pil_image):
         | 
| 212 | 
            +
                        det_img, pred = text_detection(pil_image)
         | 
| 213 | 
            +
                        return det_img, pred.model_dump(exclude=["heatmap", "affinity_map"])
         | 
| 214 | 
            +
                    text_det_btn.click(
         | 
| 215 | 
            +
                        fn=text_det_img,
         | 
| 216 | 
            +
                        inputs=[in_img],
         | 
| 217 | 
            +
                        outputs=[result_img, result_json]
         | 
| 218 | 
            +
                    )
         | 
| 219 | 
            +
                    # Run layout
         | 
| 220 | 
            +
                    def layout_det_img(pil_image):
         | 
| 221 | 
            +
                        layout_img, pred = layout_detection(pil_image)
         | 
| 222 | 
            +
                        return layout_img, pred.model_dump(exclude=["segmentation_map"])
         | 
| 223 | 
            +
                    layout_det_btn.click(
         | 
| 224 | 
            +
                        fn=layout_det_img,
         | 
| 225 | 
            +
                        inputs=[in_img],
         | 
| 226 | 
            +
                        outputs=[result_img, result_json]
         | 
| 227 | 
            +
                    )
         | 
| 228 | 
            +
                    # Run OCR
         | 
| 229 | 
            +
                    def text_rec_img(pil_image, in_file, page_number, languages):
         | 
| 230 | 
            +
                        if in_file.endswith('.pdf'):
         | 
| 231 | 
            +
                            pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
         | 
| 232 | 
            +
                        else:
         | 
| 233 | 
            +
                            pil_image_highres = pil_image
         | 
| 234 | 
            +
                        rec_img, pred = ocr(pil_image, pil_image_highres, languages)
         | 
| 235 | 
            +
                        return rec_img, pred.model_dump()
         | 
| 236 | 
            +
                    text_rec_btn.click(
         | 
| 237 | 
            +
                        fn=text_rec_img,
         | 
| 238 | 
            +
                        inputs=[in_img, in_file, in_num, lang_dd],
         | 
| 239 | 
            +
                        outputs=[result_img, result_json]
         | 
| 240 | 
            +
                    )
         | 
| 241 | 
            +
                    def table_rec_img(pil_image, in_file, page_number, use_pdf_boxes, skip_table_detection):
         | 
| 242 | 
            +
                        if in_file.endswith('.pdf'):
         | 
| 243 | 
            +
                            pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES)
         | 
| 244 | 
            +
                        else:
         | 
| 245 | 
            +
                            pil_image_highres = pil_image
         | 
| 246 | 
            +
                        table_img, pred = table_recognition(pil_image, pil_image_highres, in_file, page_number - 1 if page_number else None, use_pdf_boxes, skip_table_detection)
         | 
| 247 | 
            +
                        return table_img, [p.model_dump() for p in pred]
         | 
| 248 | 
            +
                    table_rec_btn.click(
         | 
| 249 | 
            +
                        fn=table_rec_img,
         | 
| 250 | 
            +
                        inputs=[in_img, in_file, in_num, use_pdf_boxes_ckb, skip_table_detection_ckb],
         | 
| 251 | 
            +
                        outputs=[result_img, result_json]
         | 
| 252 | 
            +
                    )
         | 
| 253 | 
            +
             | 
| 254 | 
            +
            demo.launch()
         | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            torch==2.5.1
         | 
| 2 | 
            +
            surya-ocr==0.7.0
         | 
| 3 | 
            +
            gradio==5.8.0
         | 
| 4 | 
            +
            huggingface-hub==0.26.3
         |