Spaces:
Running
Running
import gradio as gr | |
import base64 | |
import io | |
import torch | |
from PIL import Image | |
from transformers import AutoModelForImageTextToText, AutoProcessor, AutoTokenizer | |
# Hugging Face Hub identifier of the checkpoint this demo serves.
model_path = "nanonets/Nanonets-OCR-s"

# Everything below runs once at import time; the resulting module-level
# objects are reused by every OCR call.
print("Loading Nanonets OCR model...")
model = AutoModelForImageTextToText.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="cpu",  # run on CPU
).eval()
processor = AutoProcessor.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
print("Model loaded successfully!")
def ocr_image_gradio(image, max_tokens=4096):
    """Run the Nanonets OCR model on one image and return the extracted text.

    Args:
        image: A ``PIL.Image.Image``, or any object accepted by
            ``Image.fromarray``. ``None`` means no image was supplied.
        max_tokens: Upper bound on the number of newly generated tokens.
            Coerced to ``int`` unless it is ``None``.

    Returns:
        The decoded model output for the image, or the message
        ``"Please upload an image."`` when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image."

    # generate() takes an integer token budget, but a UI slider or a JSON API
    # client may hand the value over as a float (e.g. 4096.0).
    if max_tokens is not None:
        max_tokens = int(max_tokens)

    prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. If there is an image in the document and image caption is not present, add a small description of the image inside the <img></img> tag; otherwise, add the image caption inside <img></img>. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and ☑ for check boxes."""

    # Convert to a PIL image if the caller passed something else.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt")
    inputs = inputs.to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False,
            repetition_penalty=1.25,
        )

    # Each returned sequence starts with its prompt tokens; slice them off so
    # only the newly generated tokens are decoded. Distinct loop names avoid
    # rebinding ``output_ids`` inside the comprehension.
    generated_ids = [
        full_seq[len(prompt_seq):]
        for prompt_seq, full_seq in zip(inputs.input_ids, output_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    return output_text[0]
def ocr_base64_image(base64_string, max_tokens=4096):
    """Decode a base64-encoded image and run it through the OCR function.

    Args:
        base64_string: Base64 text of an image file, optionally preceded by a
            data-URL prefix ending in ``base64,``. Whitespace inside the
            payload and missing trailing ``=`` padding are tolerated.
        max_tokens: Forwarded unchanged to ``ocr_image_gradio``.

    Returns:
        The OCR result from ``ocr_image_gradio``; a fixed prompt message when
        the input is empty or blank; or a string starting with
        ``"Error processing base64 image: "`` when decoding, image loading or
        OCR raises.
    """
    if not base64_string or not base64_string.strip():
        return "Please provide a valid base64 image string."

    try:
        payload = base64_string
        # Remove data URL prefix if present.
        if "base64," in payload:
            payload = payload.split("base64,")[1]

        # Pasted payloads are often line-wrapped or have their trailing "="
        # stripped; normalise both so b64decode does not fail with
        # "Incorrect padding".
        payload = "".join(payload.split())
        payload += "=" * (-len(payload) % 4)

        # Decode base64 to image.
        image_data = base64.b64decode(payload)
        image = Image.open(io.BytesIO(image_data))

        # Process image using existing OCR function.
        return ocr_image_gradio(image, max_tokens)
    except Exception as e:
        # Deliberately broad: this is the UI/API boundary, so any failure is
        # reported back to the caller as text instead of raising.
        return f"Error processing base64 image: {str(e)}"
# Create Gradio interface: two input tabs (image upload / Base64 text), their
# event wiring, and three informational accordions.
with gr.Blocks(title="Nanonets OCR Demo") as demo:
    # Styled HTML header with the title, tagline and resource links.
    gr.HTML("""
    <div class="title" style="text-align: center">
        <h1>🔍 Nanonets OCR - Document Text Extraction</h1>
        <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
            A state-of-the-art image-to-markdown OCR model for intelligent document processing
        </p>
        <div style="display: flex; justify-content: center; gap: 20px; margin: 15px 0;">
            <a href="https://huggingface.co/nanonets/Nanonets-OCR-s" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;">
                📚 Hugging Face Model
            </a>
            <a href="https://nanonets.com/research/nanonets-ocr-s/" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;">
                📝 Release Blog
            </a>
            <a href="https://github.com/NanoNets/docext" target="_blank" style="text-decoration: none; color: #2563eb; font-weight: 500;">
                💻 GitHub Repository
            </a>
        </div>
    </div>
    """)

    with gr.Tabs() as tabs:
        # Image tab
        with gr.TabItem("Image OCR"):
            with gr.Row():
                with gr.Column(scale=1):
                    image_input = gr.Image(
                        label="上传文档图片", type="pil", height=400
                    )
                    image_max_tokens = gr.Slider(
                        minimum=1024,
                        maximum=8192,
                        value=4096,
                        step=512,
                        label="最大Token数",
                        info="生成的最大token数量",
                    )
                    image_extract_btn = gr.Button(
                        "提取文本", variant="primary", size="lg"
                    )
                with gr.Column(scale=2):
                    image_output_text = gr.Textbox(
                        label="提取的文本",
                        lines=20,
                        show_copy_button=True,
                        placeholder="提取的文本将显示在这里...",
                    )

        # Base64 Image tab
        with gr.TabItem("Base64图片OCR"):
            with gr.Row():
                with gr.Column(scale=1):
                    base64_input = gr.Textbox(
                        label="输入Base64编码的图片",
                        lines=10,
                        placeholder="粘贴Base64编码的图片数据...",
                    )
                    base64_max_tokens = gr.Slider(
                        minimum=1024,
                        maximum=8192,
                        value=4096,
                        step=512,
                        label="最大Token数",
                        info="生成的最大token数量",
                    )
                    base64_extract_btn = gr.Button(
                        "提取文本", variant="primary", size="lg"
                    )
                with gr.Column(scale=2):
                    base64_output_text = gr.Textbox(
                        label="提取的文本",
                        lines=20,
                        show_copy_button=True,
                        placeholder="提取的文本将显示在这里...",
                    )

    # Event handlers.
    # NOTE: registration order matters. Gradio assigns each listener an
    # fn_index in the order it is registered, and the API usage text further
    # down documents fn_index 0 = image OCR and fn_index 1 = Base64 OCR. The
    # two button handlers are therefore registered first; keep this order in
    # sync with that text.

    # fn_index 0: image tab button.
    image_extract_btn.click(
        fn=ocr_image_gradio,
        inputs=[image_input, image_max_tokens],
        outputs=image_output_text,
        show_progress=True,
    )
    # fn_index 1: Base64 tab button.
    base64_extract_btn.click(
        fn=ocr_base64_image,
        inputs=[base64_input, base64_max_tokens],
        outputs=base64_output_text,
        show_progress=True,
    )
    # fn_index 2: run OCR automatically whenever the uploaded image changes.
    image_input.change(
        fn=ocr_image_gradio,
        inputs=[image_input, image_max_tokens],
        outputs=image_output_text,
        show_progress=True,
    )

    # Add model information section
    with gr.Accordion("关于 Nanonets-OCR-s", open=False):
        gr.Markdown("""
        ## Nanonets-OCR-s
        Nanonets-OCR-s 是一个强大的最先进的图像到markdown的OCR模型,远超传统的文本提取功能。
        它将文档转换为带有智能内容识别和语义标记的结构化markdown,非常适合大型语言模型(LLM)的下游处理。
        ### 主要特点
        - **LaTeX公式识别**:自动将数学公式转换为格式正确的LaTeX语法。
          它区分内联($...$)和显示($$...$$)公式。
        - **智能图像描述**:使用结构化的`<img>`标签描述文档中的图像,使它们易于LLM处理。
          它可以描述各种图像类型,包括徽标、图表、图形等,详细说明它们的内容、风格和上下文。
        - **签名检测与隔离**:识别并隔离签名与其他文本,将其输出在`<signature>`标签内。
          这对处理法律和商业文件至关重要。
        - **水印提取**:检测并提取文档中的水印文本,将其放在`<watermark>`标签内。
        - **智能复选框处理**:将表单复选框和单选按钮转换为标准化的Unicode符号(☐, ☑, ☒),
          以实现一致可靠的处理。
        - **复杂表格提取**:准确地从文档中提取复杂表格,并将它们转换为markdown和HTML表格格式。
        """)

    # API Usage Information
    with gr.Accordion("API使用说明", open=True):
        gr.Markdown("""
        ## API使用方法
        ### Base64图片识别API
        您可以通过HTTP POST请求使用Base64图片识别API:
        ```
        curl -X POST "http://localhost:7860/api/predict" \\
        -H "Content-Type: application/json" \\
        -d '{
        "fn_index": 1,
        "data": [
        "YOUR_BASE64_STRING_HERE",
        4096
        ]
        }'
        ```
        - `fn_index: 1` 对应Base64图片OCR功能
        - 第一个参数是Base64编码的图片字符串
        - 第二个参数是最大token数量
        ### 普通图片上传API
        ```
        curl -X POST "http://localhost:7860/api/predict" \\
        -H "Content-Type: application/json" \\
        -d '{
        "fn_index": 0,
        "data": [
        "IMAGE_DATA_HERE",
        4096
        ]
        }'
        ```
        - `fn_index: 0` 对应普通图片OCR功能
        """)

    # CPU Usage Warning
    with gr.Accordion("CPU环境说明", open=True):
        gr.Markdown("""
        ## CPU环境性能说明
        此应用程序当前运行在CPU环境下(2核16G),请注意:
        - 处理大型图像可能需要更长时间
        - 建议使用较小的图像以获得更快的响应速度
        - 如果处理时间过长,可以考虑降低最大Token数
        - 模型已针对CPU环境进行了优化配置
        """)
if __name__ == "__main__":
    # ``torch`` is already imported at the top of the file, so it is not
    # re-imported here.
    # Log the thread count PyTorch reports before it is capped below.
    print(f"使用设备: CPU - 可用线程数: {torch.get_num_threads()}")
    # Set the thread count to tune CPU performance: 2 threads for the
    # 2 available cores.
    torch.set_num_threads(2)
    # queue() enables request queuing; server_name="0.0.0.0" listens on all
    # network interfaces and share=True asks Gradio for a public share link.
    demo.queue().launch(share=True, server_name="0.0.0.0")