Spaces:
Running
Running
| # -*- encoding: utf-8 -*- | |
| # @Author: SWHL | |
| # @Contact: [email protected] | |
| import time | |
| from pathlib import Path | |
| import cv2 | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| from PIL import Image | |
| from rapidocr_onnxruntime import RapidOCR | |
| from streamlit_image_select import image_select | |
| from utils import visualize | |
# Font file used to draw recognised text, keyed by language code.
# ``inference`` resolves the chosen file name under the ``fonts`` directory.
# English has no font of its own and shares the Chinese one.
font_dict = dict(
    ch="chinese_cht.ttf",
    japan="japan.ttc",
    korean="korean.ttf",
    en="chinese_cht.ttf",
)
def init_sidebar():
    """Render the sidebar widgets and store the user's choices in session state.

    Two entries are written to ``st.session_state``:

    * ``"params"`` -- dict holding the ``box_thresh``, ``unclip_ratio`` and
      ``text_score`` slider values.
    * ``"img"`` -- the image to recognise: the uploaded file when one is
      present, otherwise the example currently selected in the gallery.

    The gallery is filled from the module-level ``examples`` list, which is
    assigned in the ``__main__`` section, so that list must exist before this
    function is called.
    """
    st.sidebar.markdown(
        "### [🛠️ Parameter Settings](https://github.com/RapidAI/RapidOCR/wiki/config_parameter)"
    )

    box_thresh = st.sidebar.slider(
        "box_thresh",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.1,
        help="检测到的框是文本的概率,值越大,框中是文本的概率就越大。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
    )
    unclip_ratio = st.sidebar.slider(
        "unclip_ratio",
        min_value=1.5,
        max_value=2.0,
        value=1.6,
        step=0.1,
        help="控制文本检测框的大小,值越大,检测框整体越大。在出现框截断文字的情况,调大该值。取值范围:[1.5, 2.0],默认值为1.6",
    )
    text_score = st.sidebar.slider(
        "text_score",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.1,
        help="文本识别结果是正确的置信度,值越大,显示出的识别结果更准确。存在漏检时,调低该值。取值范围:[0, 1.0],默认值为0.5",
    )
    st.session_state["params"] = {
        "box_thresh": box_thresh,
        "unclip_ratio": unclip_ratio,
        "text_score": text_score,
    }

    img_file_buffer = st.sidebar.file_uploader(
        "Upload an image",
        accept_multiple_files=False,
        label_visibility="visible",
        type=["png", "jpg", "jpeg", "bmp"],
    )

    # The gallery is always rendered so the sidebar layout does not jump
    # around, but its selection is only used when nothing has been uploaded.
    with st.sidebar.container():
        img_path = image_select(
            label="Examples(click to select):",
            images=examples,
            key="equation_default",
            use_container_width=True,
        )

    if img_file_buffer:
        # An uploaded file takes precedence; previously the example image
        # unconditionally overwrote it, so uploads were never recognised.
        # NOTE(review): this array is in PIL channel order while the example
        # branch below yields OpenCV's BGR order -- confirm that downstream
        # code is fine with both.
        img = np.array(Image.open(img_file_buffer))
    else:
        # cv2.imread returns None when the file cannot be read.
        img = cv2.imread(img_path)

    st.session_state["img"] = img
def inference(
    text_det=None,
    text_rec=None,
):
    """Run RapidOCR on the image held in ``st.session_state["img"]``.

    Args:
        text_det: File name of the detection model inside
            ``models/text_det``. Must be supplied; the ``None`` default
            cannot be joined into a path.
        text_rec: File name of the recognition model inside
            ``models/text_rec``. Must be supplied for the same reason. The
            name also decides the recognition input height (32 when it
            contains ``"v2"``, otherwise 48) and the font used for drawing.

    Returns:
        A 4-tuple ``(vis_img, out_df, elapse, rec_res)``:

        * ``vis_img`` -- whatever ``visualize`` returns for the drawn result.
        * ``out_df`` -- ``pandas.DataFrame`` with columns ``Rec`` and ``Score``.
        * ``elapse`` -- markdown bullet list with the det/cls/rec timings.
        * ``rec_res`` -- tuple of the recognised texts.

        ``(None, None, None, None)`` is returned when there is no image in
        session state or when the engine produced no result.
    """
    img = st.session_state.get("img")
    if img is None:
        # No usable image (e.g. the file could not be read): report
        # "no result" instead of crashing inside the OCR engine.
        return None, None, None, None

    params = st.session_state["params"]
    box_thresh = params.get("box_thresh")
    unclip_ratio = params.get("unclip_ratio")
    text_score = params.get("text_score")

    det_model_path = str(Path("models") / "text_det" / text_det)
    rec_model_path = str(Path("models") / "text_rec" / text_rec)

    # v2 recognition models take 32-pixel-high inputs, the others 48.
    rec_image_shape = [3, 32, 320] if "v2" in rec_model_path else [3, 48, 320]

    rapid_ocr = RapidOCR(
        det_model_path=det_model_path,
        rec_model_path=rec_model_path,
        rec_img_shape=rec_image_shape,
    )

    # Pick the font from the model name. The languages with a dedicated font
    # are checked first so an incidental "ch"/"en" substring cannot shadow
    # them; every other model (Chinese, English, unknown) uses the "ch" font.
    if "japan" in rec_model_path:
        lan_name = "japan"
    elif "korean" in rec_model_path:
        lan_name = "korean"
    else:
        lan_name = "ch"

    ocr_result, infer_elapse = rapid_ocr(
        img, box_thresh=box_thresh, unclip_ratio=unclip_ratio, text_score=text_score
    )
    if not ocr_result or not infer_elapse:
        return None, None, None, None

    det_cost, cls_cost, rec_cost = infer_elapse
    elapse = f"- `det cost`: {det_cost:.5f}\n - `cls cost`: {cls_cost:.5f}\n - `rec cost`: {rec_cost:.5f}"

    # Each result entry unpacks into (box, text, score); transpose them into
    # three parallel tuples.
    dt_boxes, rec_res, scores = zip(*ocr_result)

    font_path = Path("fonts") / font_dict[lan_name]
    # NOTE(review): images loaded with cv2.imread are BGR while
    # Image.fromarray assumes RGB -- confirm whether ``visualize`` accounts
    # for that or whether the preview colours end up swapped.
    vis_img = visualize(
        Image.fromarray(img), dt_boxes, rec_res, scores, font_path=str(font_path)
    )

    out_df = pd.DataFrame(
        [list(row) for row in zip(rec_res, scores)],
        columns=("Rec", "Score"),
    )
    return vis_img, out_df, elapse, rec_res
def tips(txt: str, wait_time: int = 2, icon: str = "🎉") -> None:
    """Show ``txt`` as a Streamlit toast, then block for ``wait_time`` seconds.

    Args:
        txt: Message displayed in the toast.
        wait_time: Number of seconds to sleep after the toast is issued.
        icon: Emoji shown next to the message.
    """
    st.toast(txt, icon=icon)
    time.sleep(wait_time)
if __name__ == "__main__":
    # NOTE: this section deliberately stays at module level instead of being
    # wrapped in a ``main()`` function: ``init_sidebar`` reads ``examples``
    # as a global, so it has to be assigned in module scope.

    # Page title.
    st.markdown(
        "<h1 style='text-align: center;'><a href='https://github.com/RapidAI/RapidOCR' style='text-decoration: none'>Rapid⚡OCR</a></h1>",
        unsafe_allow_html=True,
    )

    # Badge row. The documentation badge's <a> is now closed before </p>;
    # it was previously left open, producing malformed HTML.
    st.markdown(
        """
<p align="left">
<a href=""><img src="https://img.shields.io/badge/Python->=3.6,<3.12-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
<a href="https://pepy.tech/project/rapidocr_onnxruntime"><img src="https://static.pepy.tech/personalized-badge/rapidocr_onnxruntime?period=total&units=abbreviation&left_color=grey&right_color=blue&left_text=Downloads%20Ort"></a>
<a href="https://pypi.org/project/rapidocr-onnxruntime/"><img alt="PyPI" src="https://img.shields.io/pypi/v/rapidocr-onnxruntime"></a>
<a href='https://rapidocr.readthedocs.io/en/latest/?badge=latest'>
<img src='https://readthedocs.org/projects/rapidocr/badge/?version=latest' alt='Documentation Status' /></a>
</p>
""",
        unsafe_allow_html=True,
    )

    # Example images offered in the sidebar gallery (read by init_sidebar).
    examples = [
        "images/1.jpg",
        "images/ch_en_num.jpg",
        "images/air_ticket.jpg",
        "images/car_plate.jpeg",
        "images/train_ticket.jpeg",
        "images/japan_2.jpg",
        "images/korean_1.jpg",
    ]

    init_sidebar()

    # Model pickers, side by side.
    menu_det, menu_rec = st.columns([1, 1])

    det_models = [
        "ch_PP-OCRv4_det_infer.onnx",
        "ch_PP-OCRv3_det_infer.onnx",
        "ch_PP-OCRv2_det_infer.onnx",
        "ch_ppocr_server_v2.0_det_infer.onnx",
    ]
    select_det = menu_det.selectbox("Det model:", det_models)

    rec_models = [
        "ch_PP-OCRv4_rec_infer.onnx",
        "ch_PP-OCRv3_rec_infer.onnx",
        "ch_PP-OCRv2_rec_infer.onnx",
        "ch_ppocr_server_v2.0_rec_infer.onnx",
        "en_PP-OCRv3_rec_infer.onnx",
        "en_number_mobile_v2.0_rec_infer.onnx",
        "korean_mobile_v2.0_rec_infer.onnx",
        "japan_rec_crnn_v2.onnx",
    ]
    select_rec = menu_rec.selectbox("Rec model:", rec_models)

    out_img, out_json, elapse, only_txts = inference(select_det, select_rec)

    # inference() returns None in every slot when nothing was recognised.
    if out_img is None or out_json is None or elapse is None:
        tips("识别结果为空", wait_time=5, icon="⚠️")
    else:
        st.markdown("#### Visualize:")
        st.image(out_img)

        st.markdown("### Rec Result:")
        st.markdown(elapse)
        st.dataframe(out_json, use_container_width=True)

        st.markdown("### Only Txts")
        st.code(only_txts)